1、基础操作

package com.journey.sql;

import com.alibaba.fastjson.JSON;
import com.journey.sql.bean.User;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.apache.spark.sql.functions.col;

public class SparkSQLTest {

    public static void main(String[] args) throws Exception {
        SparkSession spark = SparkSession
                .builder()
                .appName("Demo")
                .master("local[*]")
                .getOrCreate();
        // 读取 json 文件 创建 DataFrame {"username": "lisi","age": 18},DataFrame是一种特殊的Dataset,行是Row
        Dataset<Row> df = spark.read().json("datas/sql/user.json");
        // 展示表结构 + 数据
        df.show();
        // 打印schema结构
        df.printSchema();
        // 直接select
        df.select("username").show();
        // 加1
        df.select(col("username"), col("age").plus(1)).show();
        // 过滤age大于19
        df.filter(col("age").gt(19)).show();
        // 统计age的个数
        df.groupBy("age").count().show();

        df.createOrReplaceTempView("user");

        // 使用sql来查询
        Dataset<Row> sqlDF = spark.sql("select * from user");
        sqlDF.show();

        // 注册DataFrame作为一个全局的临时视图
        df.createGlobalTempView("user2");

        spark.sql("select * from global_temp.user2").show();
        spark.newSession().sql("select * from global_temp.user2").show();


        /**
         * 数据集与 RDD 类似,但是,它们不使用 Java 序列化或 Kryo,而是使用专门的编码器来序列化对象以进行处理或通过网络传输。
         * 虽然编码器和标准序列化都负责将对象转换为字节,但编码器是动态生成的代码,并使用一种格式,允许 Spark 执行许多操作,
         * 如过滤、排序和散列,而无需将字节反序列化回对象。
         */
        // 注意 : User不能是static修饰
        User user = new User("qiaozhanwei", 20);
        Encoder<User> userEncoder = Encoders.bean(User.class);
        Dataset<User> javaBeanDS = spark.createDataset(Collections.singletonList(user), userEncoder);
        javaBeanDS.show();

        Encoder<Integer> integerEncoder = Encoders.INT();
        Dataset<Integer> primitiveDS = spark.createDataset(Arrays.asList(1, 2, 3), integerEncoder);
        Dataset<Integer> transformedDS = primitiveDS.map(
                (MapFunction<Integer, Integer>) value -> value + 1,
                integerEncoder);
        // java: 不兼容的类型: java.lang.Object无法转换为java.lang.Integer[],跑不通
        // Integer[] collect = transformedDS.collect();
        transformedDS.show();

        Dataset<User> userDS = spark.read().json("datas/sql/user.json").as(userEncoder);
        userDS.show();

        JavaRDD<User> userRDD = spark.read().textFile("datas/sql/user.json")
                .javaRDD()
                .map(line -> {
                    User userInfo = JSON.parseObject(line, User.class);
                    return userInfo;
                });

        Dataset<Row> user3DF = spark.createDataFrame(userRDD, User.class);
        user3DF.createOrReplaceTempView("user3");

        List<User> userList = new ArrayList<>();
        userList.add(new User("haha", 30));

        Dataset<Row> dataFrame = spark.createDataFrame(userList, User.class);
        dataFrame.show();

        Dataset<Row> teenagerDF = spark.sql("select * from user3 where age between 13 and 20");

        Encoder<String> stringEncoder = Encoders.STRING();
        Dataset<String> teenagerNamesByIndexDF = teenagerDF.map(new MapFunction<Row, String>() {
            @Override
            public String call(Row value) throws Exception {
                return "Name : " + value.getString(1);
            }
        }, stringEncoder);
        teenagerNamesByIndexDF.show();

        Dataset<String> teenagerNamesByFieldDF = teenagerDF.map(
                (MapFunction<Row, String>) row -> "Name: " + row.<String>getAs("userName"),
                stringEncoder);
        teenagerNamesByFieldDF.show();

        // 定义用户名字段类型
        StructField userNameField = DataTypes.createStructField("name", DataTypes.StringType, true);
        // 定义年龄字段类型
        StructField ageField = DataTypes.createStructField("age", DataTypes.IntegerType, true);
        List<StructField> fields = new ArrayList<>();
        fields.add(userNameField);
        fields.add(ageField);

        StructType schema = DataTypes.createStructType(fields);

        JavaRDD<String> user2RDD = spark.sparkContext().textFile("datas/sql/user.txt", 2).toJavaRDD();
        JavaRDD<Row> rowRDD = user2RDD.map(new Function<String, Row>() {
            @Override
            public Row call(String value) throws Exception {
                String[] fields = value.split(",");
                return RowFactory.create(fields[0], Integer.parseInt(fields[1]));
            }
        });

        Dataset<Row> user4DF = spark.createDataFrame(rowRDD, schema);
        user4DF.createOrReplaceTempView("user4");

        spark.sql("select * from user4").show();


        spark.stop();
    }
}

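RDD、DataFrame和Dataset的关系及转换
（图：image.png）

要点（对应上文代码）：
- DataFrame 就是 Dataset<Row>；Dataset<T> 是通过编码器（Encoder）序列化的强类型数据集。
- RDD → DataFrame：spark.createDataFrame(rdd, User.class)（通过 JavaBean 反射推断 schema），或 spark.createDataFrame(rowRDD, schema)（显式指定 StructType）。
- DataFrame → Dataset：df.as(Encoders.bean(User.class))。
- Dataset → DataFrame：ds.toDF()。
- DataFrame / Dataset → RDD：df.javaRDD()。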

2、UDF函数

标量函数

package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

import java.util.ArrayList;
import java.util.List;

/**
 * Registers a scalar UDF (one value in, one value out per row) and calls it from SQL.
 */
public class ScalarFunctionTest {

    public static void main(String[] args) {

        SparkSession spark = SparkSession
                .builder()
                .appName("ScalarFunctionTest")
                .master("local[*]")
                .getOrCreate();
        try {
            // UDF1 takes one argument; UDF2, UDF3, ... exist for more arguments.
            // A SQL NULL argument arrives as null: propagate NULL instead of producing "baidu-null".
            UDF1<String, String> myUdf = value -> value == null ? null : "baidu-" + value;

            // Register under the SQL name "myUdf"; the declared return type must match call().
            spark.udf().register("myUdf", myUdf, DataTypes.StringType);

            List<User> userList = new ArrayList<>();
            userList.add(new User("zhangsan", 20));

            Dataset<Row> df = spark.createDataFrame(userList, User.class);
            df.createOrReplaceTempView("user");

            spark.sql("select myUdf(userName) from user").show();
        } finally {
            // Always release the local Spark context.
            spark.stop();
        }
    }
}

聚合函数

弱类型

package com.journey.sql;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

import java.util.ArrayList;
import java.util.List;

/**
 * Untyped (Row-based) user-defined aggregate function that averages one LONG input column.
 *
 * <p>NULL inputs are skipped. If a group contains no non-NULL value, the result is NULL,
 * like the built-in AVG, rather than NaN from 0/0.
 *
 * <p>NOTE: UserDefinedAggregateFunction is deprecated since Spark 3.0. The typed
 * {@code Aggregator} registered via {@code functions.udaf} is the recommended replacement.
 */
public class MyAverage1 extends UserDefinedAggregateFunction {

    /** One nullable LONG input column. */
    private final StructType inputSchema;
    /** Aggregation buffer: slot 0 = running sum, slot 1 = number of non-NULL inputs. */
    private final StructType bufferSchema;

    public MyAverage1() {
        List<StructField> inputFields = new ArrayList<>();
        inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
        inputSchema = DataTypes.createStructType(inputFields);

        List<StructField> bufferFields = new ArrayList<>();
        bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
        bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
        bufferSchema = DataTypes.createStructType(bufferFields);
    }

    @Override
    public StructType inputSchema() {
        return inputSchema;
    }

    @Override
    public StructType bufferSchema() {
        return bufferSchema;
    }

    /** The aggregate result is a DOUBLE. */
    @Override
    public DataType dataType() {
        return DataTypes.DoubleType;
    }

    /** Whether this function always returns the same output for the same input. */
    @Override
    public boolean deterministic() {
        return true;
    }

    /** Initializes an empty buffer: sum = 0, count = 0. */
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        buffer.update(0, 0L);
        buffer.update(1, 0L);
    }

    /** Folds one input row into the buffer; NULL inputs are ignored. */
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        if (!input.isNullAt(0)) {
            long updatedSum = buffer.getLong(0) + input.getLong(0);
            long updatedCount = buffer.getLong(1) + 1;
            buffer.update(0, updatedSum);
            buffer.update(1, updatedCount);
        }
    }

    /** Merges partial buffer2 into buffer1 by adding sums and counts. */
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
        long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
        buffer1.update(0, mergedSum);
        buffer1.update(1, mergedCount);
    }

    /**
     * Computes the final average.
     *
     * @return sum / count, or null when no non-NULL input was aggregated
     */
    @Override
    public Double evaluate(Row buffer) {
        long count = buffer.getLong(1);
        if (count == 0L) {
            // Empty or all-NULL group: 0/0 would yield NaN; SQL semantics call for NULL.
            return null;
        }
        return (double) buffer.getLong(0) / count;
    }
}

强类型

package com.journey.sql;

import java.io.Serializable;

public class Average implements Serializable {

    private long sum;
    private long count;

    public Average() {}

    public Average(long sum, long count) {
        this.sum = sum;
        this.count = count;
    }

    public long getSum() {
        return sum;
    }

    public void setSum(long sum) {
        this.sum = sum;
    }

    public long getCount() {
        return count;
    }

    public void setCount(long count) {
        this.count = count;
    }
}
package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.Aggregator;


/**
 * Strongly typed aggregator that averages {@code User.getAge()} over a {@code Dataset<User>}.
 *
 * <p>The buffer is an {@link Average} bean. An empty input yields null rather than NaN.
 */
public class MyAverage2 extends Aggregator<User, Average, Double> {

    /** Empty buffer: sum = 0, count = 0. */
    @Override
    public Average zero() {
        return new Average(0L, 0L);
    }

    /** Adds one user's age to the buffer; the buffer is updated in place and returned. */
    @Override
    public Average reduce(Average buffer, User user) {
        long newSum = buffer.getSum() + user.getAge();
        long newCount = buffer.getCount() + 1;
        buffer.setSum(newSum);
        buffer.setCount(newCount);
        return buffer;
    }

    /** Combines two partial buffers; b1 is updated in place and returned. */
    @Override
    public Average merge(Average b1, Average b2) {
        long mergedSum = b1.getSum() + b2.getSum();
        long mergedCount = b1.getCount() + b2.getCount();
        b1.setSum(mergedSum);
        b1.setCount(mergedCount);
        return b1;
    }

    /**
     * Produces the final average.
     *
     * @return sum / count, or null when nothing was reduced
     */
    @Override
    public Double finish(Average reduction) {
        long count = reduction.getCount();
        if (count == 0L) {
            // No rows: avoid NaN from 0/0. Encoders.DOUBLE() is a nullable encoder.
            return null;
        }
        return (double) reduction.getSum() / count;
    }

    /** Encoder for the intermediate buffer. */
    @Override
    public Encoder<Average> bufferEncoder() {
        return Encoders.bean(Average.class);
    }

    /** Encoder for the final result. */
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
package com.journey.sql;

import com.journey.sql.bean.User;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.TypedColumn;

import java.util.ArrayList;
import java.util.List;

/**
 * Runs the untyped UDAF ({@code MyAverage1}) from SQL and the typed Aggregator
 * ({@code MyAverage2}) through the Dataset API.
 */
public class AggregationsTest {

    public static void main(String[] args) throws Exception {
        SparkSession spark = SparkSession
                .builder()
                .appName("AggregationsTest")
                .master("local[*]")
                .getOrCreate();
        try {
            spark.udf().register("myAverage1", new MyAverage1());
            // An Aggregator is not a UserDefinedAggregateFunction, so it is not passed to
            // udf().register here. It is applied through the typed DSL via toColumn() below.
            // NOTE(review): Spark 3.0+ offers functions.udaf(aggregator, inputEncoder) for SQL
            // registration; confirm it suits this bean-typed input before switching.

            List<User> userList = new ArrayList<>();
            userList.add(new User("qiaozhanwei", 34));
            userList.add(new User("zhangsan", 34));

            Dataset<Row> df = spark.createDataFrame(userList, User.class);
            df.createOrReplaceTempView("user");

            spark.sql("select myAverage1(age) from user").show();

            MyAverage2 myAverage = new MyAverage2();
            TypedColumn<User, Double> averageAge = myAverage.toColumn().name("average_age");
            Encoder<User> userEncoder = Encoders.bean(User.class);
            Dataset<Double> average = df.as(userEncoder).select(averageAge);
            average.show();
        } finally {
            // Always release the local Spark context.
            spark.stop();
        }
    }
}

如感兴趣,点赞加关注,谢谢!!!


journey
32 声望22 粉丝

« 上一篇
Spark原理
下一篇 »
Kafka原理