1. Basic Operations
package com.journey.sql;
import com.alibaba.fastjson.JSON;
import com.journey.sql.bean.User;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.apache.spark.sql.functions.col;
public class SparkSQLTest {
public static void main(String[] args) throws Exception {
SparkSession spark = SparkSession
.builder()
.appName("Demo")
.master("local[*]")
.getOrCreate();
// Read a JSON file to create a DataFrame, e.g. {"username": "lisi","age": 18}; a DataFrame is a special Dataset whose rows are Row objects
Dataset<Row> df = spark.read().json("datas/sql/user.json");
// Show the column headers and the data
df.show();
// Print the schema
df.printSchema();
// Select a single column directly
df.select("username").show();
// Select username together with age + 1
df.select(col("username"), col("age").plus(1)).show();
// Filter rows where age is greater than 19
df.filter(col("age").gt(19)).show();
// Group by age and count the rows in each group
df.groupBy("age").count().show();
df.createOrReplaceTempView("user");
// Query the temp view with SQL
Dataset<Row> sqlDF = spark.sql("select * from user");
sqlDF.show();
// Register the DataFrame as a global temporary view, visible to all sessions under the global_temp database
df.createGlobalTempView("user2");
spark.sql("select * from global_temp.user2").show();
spark.newSession().sql("select * from global_temp.user2").show();
/**
 * Datasets are similar to RDDs; however, instead of Java serialization or Kryo,
 * they use specialized Encoders to serialize objects for processing or for
 * transfer over the network. While both encoders and standard serialization
 * turn objects into bytes, encoders are generated dynamically as code and use a
 * format that lets Spark perform many operations such as filtering, sorting,
 * and hashing without deserializing the bytes back into objects.
 */
// Note: User must be a public JavaBean with a no-arg constructor; a non-static inner class will not work with Encoders.bean
User user = new User("qiaozhanwei", 20);
Encoder<User> userEncoder = Encoders.bean(User.class);
Dataset<User> javaBeanDS = spark.createDataset(Collections.singletonList(user), userEncoder);
javaBeanDS.show();
Encoder<Integer> integerEncoder = Encoders.INT();
Dataset<Integer> primitiveDS = spark.createDataset(Arrays.asList(1, 2, 3), integerEncoder);
Dataset<Integer> transformedDS = primitiveDS.map(
(MapFunction<Integer, Integer>) value -> value + 1,
integerEncoder);
// Does not compile in Java: "incompatible types: java.lang.Object cannot be converted to java.lang.Integer[]"
// Integer[] collect = transformedDS.collect();
// Use collectAsList() instead, which returns a properly typed List<Integer>
List<Integer> collected = transformedDS.collectAsList();
transformedDS.show();
Dataset<User> userDS = spark.read().json("datas/sql/user.json").as(userEncoder);
userDS.show();
// Parse each JSON line into a User bean with fastjson, yielding an RDD of Users
JavaRDD<User> userRDD = spark.read().textFile("datas/sql/user.json")
.javaRDD()
.map(line -> JSON.parseObject(line, User.class));
Dataset<Row> user3DF = spark.createDataFrame(userRDD, User.class);
user3DF.createOrReplaceTempView("user3");
List<User> userList = new ArrayList<>();
userList.add(new User("haha", 30));
Dataset<Row> dataFrame = spark.createDataFrame(userList, User.class);
dataFrame.show();
Dataset<Row> teenagerDF = spark.sql("select * from user3 where age between 13 and 20");
Encoder<String> stringEncoder = Encoders.STRING();
// Access a Row column by positional index
Dataset<String> teenagerNamesByIndexDF = teenagerDF.map(new MapFunction<Row, String>() {
@Override
public String call(Row value) throws Exception {
return "Name : " + value.getString(1);
}
}, stringEncoder);
teenagerNamesByIndexDF.show();
// Access a Row column by field name
Dataset<String> teenagerNamesByFieldDF = teenagerDF.map(
(MapFunction<Row, String>) row -> "Name: " + row.<String>getAs("userName"),
stringEncoder);
teenagerNamesByFieldDF.show();
// Define the name field
StructField userNameField = DataTypes.createStructField("name", DataTypes.StringType, true);
// Define the age field
StructField ageField = DataTypes.createStructField("age", DataTypes.IntegerType, true);
List<StructField> fields = new ArrayList<>();
fields.add(userNameField);
fields.add(ageField);
StructType schema = DataTypes.createStructType(fields);
JavaRDD<String> user2RDD = spark.sparkContext().textFile("datas/sql/user.txt", 2).toJavaRDD();
JavaRDD<Row> rowRDD = user2RDD.map(new Function<String, Row>() {
@Override
public Row call(String value) throws Exception {
// Each line in user.txt is expected to be "name,age"
String[] parts = value.split(",");
return RowFactory.create(parts[0], Integer.parseInt(parts[1]));
}
});
Dataset<Row> user4DF = spark.createDataFrame(rowRDD, schema);
user4DF.createOrReplaceTempView("user4");
spark.sql("select * from user4").show();
spark.stop();
}
}
How RDD, DataFrame, and Dataset relate to one another, and how to convert between them
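To make those conversions concrete, here is a minimal sketch; it assumes the spark session, the userRDD, and the User bean from SparkSQLTest above:

// RDD -> DataFrame: attach a schema reflected from the JavaBean
Dataset<Row> beanDF = spark.createDataFrame(userRDD, User.class);
// DataFrame -> typed Dataset: attach an encoder
Dataset<User> beanDS = beanDF.as(Encoders.bean(User.class));
// Dataset -> RDD: drop the schema and go back to raw JVM objects
JavaRDD<User> backToRDD = beanDS.javaRDD();
// DataFrame -> RDD of generic Rows
JavaRDD<Row> rowRDDAgain = beanDF.javaRDD();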
2. UDF Functions
Scalar Functions
package com.journey.sql;
import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;
import java.util.ArrayList;
import java.util.List;
public class ScalarFunctionTest {
public static void main(String[] args) {
SparkSession spark = SparkSession
.builder()
.appName("ScalarFunctionTest")
.master("local[*]")
.getOrCreate();
// UDF1 takes one argument; UDF2 through UDF22 cover functions with more parameters (see the sketch after this class)
UDF1<String, String> myUdf = new UDF1<String, String>() {
@Override
public String call(String value) throws Exception {
return "baidu-" + value;
}
};
// Register the function so it can be called from SQL
spark.udf().register("myUdf", myUdf, DataTypes.StringType);
List<User> userList = new ArrayList<>();
userList.add(new User("zhangsan", 20));
Dataset<Row> df = spark.createDataFrame(userList, User.class);
df.createOrReplaceTempView("user");
spark.sql("select myUdf(userName) from user").show();
spark.stop();
}
}
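Multi-argument functions follow the same pattern with UDF2 and friends. A minimal sketch, reusing the user view above; the function name concatAge and its logic are illustrative (not from the original code), and it additionally requires importing org.apache.spark.sql.api.java.UDF2:

// Hypothetical two-argument UDF combining a name and an age into one string
UDF2<String, Integer, String> concatAge = (name, age) -> name + ":" + age;
spark.udf().register("concatAge", concatAge, DataTypes.StringType);
spark.sql("select concatAge(userName, age) from user").show();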
Aggregate Functions
Weakly typed
package com.journey.sql;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import java.util.ArrayList;
import java.util.List;
// Untyped (weakly typed) user-defined aggregate function
public class MyAverage1 extends UserDefinedAggregateFunction {
private StructType inputSchema;
private StructType bufferSchema;
public MyAverage1() {
List<StructField> inputFields = new ArrayList<>();
inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
inputSchema = DataTypes.createStructType(inputFields);
List<StructField> bufferFields = new ArrayList<>();
bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
bufferSchema = DataTypes.createStructType(bufferFields);
}
@Override
public StructType inputSchema() {
return inputSchema;
}
@Override
public StructType bufferSchema() {
return bufferSchema;
}
@Override
public DataType dataType() {
return DataTypes.DoubleType;
}
// Whether this function always returns the same output for the same input
@Override
public boolean deterministic() {
return true;
}
// Initialize the aggregation buffer
@Override
public void initialize(MutableAggregationBuffer buffer) {
buffer.update(0, 0L);
buffer.update(1, 0L);
}
// Update the intermediate state with one input row
@Override
public void update(MutableAggregationBuffer buffer, Row input) {
if (!input.isNullAt(0)) {
long updatedSum = buffer.getLong(0) + input.getLong(0);
long updatedCount = buffer.getLong(1) + 1;
buffer.update(0, updatedSum);
buffer.update(1, updatedCount);
}
}
// Merge two buffers from different partitions
@Override
public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
buffer1.update(0, mergedSum);
buffer1.update(1, mergedCount);
}
// Compute the final result
@Override
public Double evaluate(Row buffer) {
return (double) buffer.getLong(0) / buffer.getLong(1);
}
}
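Note: as of Spark 3.0, UserDefinedAggregateFunction is deprecated in favor of the strongly typed Aggregator shown next, which can also be exposed to SQL via org.apache.spark.sql.functions.udaf (a sketch appears at the end of this section).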
Strongly typed
package com.journey.sql;
import java.io.Serializable;
public class Average implements Serializable {
private long sum;
private long count;
public Average() {}
public Average(long sum, long count) {
this.sum = sum;
this.count = count;
}
public long getSum() {
return sum;
}
public void setSum(long sum) {
this.sum = sum;
}
public long getCount() {
return count;
}
public void setCount(long count) {
this.count = count;
}
}
package com.journey.sql;
import com.journey.sql.bean.User;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;
public class MyAverage2 extends Aggregator<User, Average, Double> {
// The zero value for this aggregation; must satisfy buffer + zero = buffer
@Override
public Average zero() {
return new Average(0L, 0L);
}
// Fold one input User into the running buffer
@Override
public Average reduce(Average buffer, User user) {
long newSum = buffer.getSum() + user.getAge();
long newCount = buffer.getCount() + 1;
buffer.setSum(newSum);
buffer.setCount(newCount);
return buffer;
}
// Merge two intermediate buffers (e.g. from different partitions)
@Override
public Average merge(Average b1, Average b2) {
long mergedSum = b1.getSum() + b2.getSum();
long mergedCount = b1.getCount() + b2.getCount();
b1.setSum(mergedSum);
b1.setCount(mergedCount);
return b1;
}
// Transform the final buffer into the output value
@Override
public Double finish(Average reduction) {
return (double) reduction.getSum() / reduction.getCount();
}
@Override
public Encoder<Average> bufferEncoder() {
return Encoders.bean(Average.class);
}
@Override
public Encoder<Double> outputEncoder() {
return Encoders.DOUBLE();
}
}
package com.journey.sql;
import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;
import java.util.ArrayList;
import java.util.List;
public class AggregationsTest {
public static void main(String[] args) throws Exception {
SparkSession spark = SparkSession
.builder()
.appName("AggregationsTest")
.master("local[*]")
.getOrCreate();
spark.udf().register("myAverage1", new MyAverage1());
// An Aggregator cannot be registered this way; invoke it through the typed DSL instead
// (on Spark 3.0+ it can be registered via functions.udaf; see the sketch after this class)
// spark.udf().register("myAverage2", new MyAverage2());
List<User> userList = new ArrayList<>();
userList.add(new User("qiaozhanwei", 34));
userList.add(new User("zhangsan", 34));
Dataset<Row> df = spark.createDataFrame(userList, User.class);
df.createOrReplaceTempView("user");
spark.sql("select myAverage1(age) from user").show();
MyAverage2 myAverage = new MyAverage2();
TypedColumn<User, Double> averageAge = myAverage.toColumn().name("average_age");
Encoder<User> userEncoder = Encoders.bean(User.class);
Dataset<Double> average = df.as(userEncoder).select(averageAge);
average.show();
spark.stop();
}
}
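As noted in the comment above: on Spark 3.0 and later, a typed Aggregator can in fact be registered for SQL by wrapping it with functions.udaf plus an input encoder. The sketch below is a hypothetical variant (myIntAverage is not from the original code) that aggregates the age column directly, since a SQL-callable aggregate receives column values rather than whole User objects; it assumes Spark 3.0+, an import of org.apache.spark.sql.expressions.Aggregator, and a static import of org.apache.spark.sql.functions.udaf:

// Hypothetical Aggregator over the age column itself (Integer in, Double out)
Aggregator<Integer, Average, Double> myIntAverage = new Aggregator<Integer, Average, Double>() {
    @Override public Average zero() { return new Average(0L, 0L); }
    @Override public Average reduce(Average b, Integer age) {
        b.setSum(b.getSum() + age);
        b.setCount(b.getCount() + 1);
        return b;
    }
    @Override public Average merge(Average b1, Average b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }
    @Override public Double finish(Average r) { return (double) r.getSum() / r.getCount(); }
    @Override public Encoder<Average> bufferEncoder() { return Encoders.bean(Average.class); }
    @Override public Encoder<Double> outputEncoder() { return Encoders.DOUBLE(); }
};
// Registered this way, the Aggregator becomes callable from SQL
spark.udf().register("myIntAverage", udaf(myIntAverage, Encoders.INT()));
spark.sql("select myIntAverage(age) from user").show();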
If you found this interesting, a like and a follow are appreciated. Thanks!