This article takes a look at AggregateFunction in Flink's Table API.

Example

import java.util.Iterator;
import org.apache.flink.table.functions.AggregateFunction;

/**
 * Accumulator for WeightedAvg.
 */
public static class WeightedAvgAccum {
    public long sum = 0;
    public int count = 0;
}

/**
 * Weighted Average user-defined aggregate function.
 */
public static class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {

    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }

    @Override
    public Long getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return 0L;
        } else {
            return acc.sum / acc.count;
        }
    }

    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }

    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum -= iValue * iWeight;
        acc.count -= iWeight;
    }
    
    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        Iterator<WeightedAvgAccum> iter = it.iterator();
        while (iter.hasNext()) {
            WeightedAvgAccum a = iter.next();
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }
    
    public void resetAccumulator(WeightedAvgAccum acc) {
        acc.count = 0;
        acc.sum = 0L;
    }
}

// register function
BatchTableEnvironment tEnv = ...
tEnv.registerFunction("wAvg", new WeightedAvg());

// use function
tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user");
  • WeightedAvg extends AggregateFunction and implements the createAccumulator, getValue, accumulate, retract, merge and resetAccumulator methods.
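To see what each method does to the accumulator, the following plain-Java sketch re-implements the WeightedAvg logic without any Flink dependency. WeightedAvgDemo and Acc are hypothetical stand-ins, not part of Flink's API; they only illustrate how accumulate, retract, merge and getValue change sum and count.

```java
import java.util.Arrays;

// Plain-Java stand-in for the WeightedAvg UDAF above (no Flink dependency).
// WeightedAvgDemo and Acc are hypothetical names used only for illustration.
class WeightedAvgDemo {

    // Same fields as WeightedAvgAccum.
    static class Acc {
        long sum = 0;
        int count = 0;
    }

    static void accumulate(Acc acc, long value, int weight) {
        acc.sum += value * weight;
        acc.count += weight;
    }

    static void retract(Acc acc, long value, int weight) {
        acc.sum -= value * weight;
        acc.count -= weight;
    }

    // Folds other accumulators into acc; acc may already hold data, so it is not cleared.
    static void merge(Acc acc, Iterable<Acc> others) {
        for (Acc a : others) {
            acc.sum += a.sum;
            acc.count += a.count;
        }
    }

    // Integer division, so the weighted average is truncated.
    static Long getValue(Acc acc) {
        return acc.count == 0 ? 0L : acc.sum / acc.count;
    }

    public static void main(String[] args) {
        Acc a = new Acc();
        accumulate(a, 10, 1);            // sum=10, count=1
        accumulate(a, 40, 2);            // sum=90, count=3
        System.out.println(getValue(a)); // 30

        Acc b = new Acc();
        accumulate(b, 70, 1);            // sum=70, count=1
        merge(a, Arrays.asList(b));      // sum=160, count=4
        System.out.println(getValue(a)); // 40

        retract(a, 10, 1);               // sum=150, count=3
        System.out.println(getValue(a)); // 50
    }
}
```

As in the original WeightedAvg, getValue uses integer division, so the result is truncated rather than rounded.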

AggregateFunction

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/functions/AggregateFunction.scala

abstract class AggregateFunction[T, ACC] extends UserDefinedFunction {
  /**
    * Creates and init the Accumulator for this [[AggregateFunction]].
    *
    * @return the accumulator with the initial value
    */
  def createAccumulator(): ACC

  /**
    * Called every time when an aggregation result should be materialized.
    * The returned value could be either an early and incomplete result
    * (periodically emitted as data arrive) or the final result of the
    * aggregation.
    *
    * @param accumulator the accumulator which contains the current
    *                    aggregated results
    * @return the aggregation result
    */
  def getValue(accumulator: ACC): T

  /**
    * Returns true if this AggregateFunction can only be applied in an OVER window.
    *
    * @return true if the AggregateFunction requires an OVER window, false otherwise.
    */
  def requiresOver: Boolean = false

  /**
    * Returns the TypeInformation of the AggregateFunction's result.
    *
    * @return The TypeInformation of the AggregateFunction's result or null if the result type
    *         should be automatically inferred.
    */
  def getResultType: TypeInformation[T] = null

  /**
    * Returns the TypeInformation of the AggregateFunction's accumulator.
    *
    * @return The TypeInformation of the AggregateFunction's accumulator or null if the
    *         accumulator type should be automatically inferred.
    */
  def getAccumulatorType: TypeInformation[ACC] = null
}
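  • AggregateFunction extends UserDefinedFunction and has two type parameters: T, the type of the aggregation result, and ACC, the type of the accumulator. It defines createAccumulator, getValue, requiresOver, getResultType and getAccumulatorType; of these, subclasses must implement createAccumulator and getValue, while the others have default implementations.
  • The accumulate method is not declared on AggregateFunction, but every subclass must define and implement it: it takes the accumulator (ACC) as its first argument, followed by the function's input parameters, and returns void. The retract, merge and resetAccumulator methods are optional and only need to be defined and implemented when the intended use requires them.
  • For bounded OVER aggregations on DataStream, retract must be implemented; it takes the accumulator followed by the same input parameters as accumulate and returns void. For session-window grouping aggregations on DataStream and for grouping aggregations on DataSet, merge must be implemented; it takes an ACC and a java.lang.Iterable<ACC> of accumulators to merge, and returns void. For grouping aggregations on DataSet, resetAccumulator must be implemented as well; it takes an ACC and returns void.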
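Because accumulate (and the optional retract, merge and resetAccumulator) are not declared on the base class, a framework has to locate them by name at runtime. The sketch below illustrates that idea with plain Java reflection. MyAggregateFunction, SumAgg and AccumulateLookup are hypothetical names, and this is an assumption-based illustration, not Flink's own method-resolution logic.

```java
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

// Hypothetical base class: like AggregateFunction, it declares only
// createAccumulator and getValue, not accumulate.
abstract class MyAggregateFunction<T, ACC> {
    public abstract ACC createAccumulator();
    public abstract T getValue(ACC acc);
}

// Hypothetical UDAF with two overloaded accumulate methods.
class SumAgg extends MyAggregateFunction<Long, long[]> {
    public long[] createAccumulator() { return new long[1]; }
    public Long getValue(long[] acc) { return acc[0]; }
    public void accumulate(long[] acc, long v) { acc[0] += v; }
    public void accumulate(long[] acc, int v, int times) { acc[0] += (long) v * times; }
}

class AccumulateLookup {
    // Returns every public method called `name` that returns void and whose
    // first parameter accepts the accumulator class.
    static List<Method> findContractMethods(Class<?> udaf, String name, Class<?> accClass) {
        List<Method> found = new ArrayList<>();
        for (Method m : udaf.getMethods()) {
            Class<?>[] params = m.getParameterTypes();
            if (m.getName().equals(name)
                    && m.getReturnType() == void.class
                    && params.length >= 1
                    && params[0].isAssignableFrom(accClass)) {
                found.add(m);
            }
        }
        return found;
    }

    public static void main(String[] args) throws Exception {
        SumAgg agg = new SumAgg();
        long[] acc = agg.createAccumulator();
        List<Method> accs = findContractMethods(SumAgg.class, "accumulate", long[].class);
        System.out.println(accs.size()); // 2
        // Invoke the two-argument overload reflectively.
        for (Method m : accs) {
            if (m.getParameterCount() == 2) {
                m.invoke(agg, acc, 5L);
            }
        }
        System.out.println(agg.getValue(acc)); // 5
        // No retract is defined, so this optional method is simply absent.
        System.out.println(findContractMethods(SumAgg.class, "retract", long[].class).isEmpty()); // true
    }
}
```

Overloading fits naturally into this approach: both accumulate variants are found, and the caller chooses one by the number and types of its arguments.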

DataSetPreAggFunction

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala

class DataSetPreAggFunction(genAggregations: GeneratedAggregationsFunction)
  extends AbstractRichFunction
  with GroupCombineFunction[Row, Row]
  with MapPartitionFunction[Row, Row]
  with Compiler[GeneratedAggregations]
  with Logging {

  private var output: Row = _
  private var accumulators: Row = _

  private var function: GeneratedAggregations = _

  override def open(config: Configuration) {
    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
                s"Code:\n$genAggregations.code")
    val clazz = compile(
      getRuntimeContext.getUserCodeClassLoader,
      genAggregations.name,
      genAggregations.code)
    LOG.debug("Instantiating AggregateHelper.")
    function = clazz.newInstance()

    output = function.createOutputRow()
    accumulators = function.createAccumulators()
  }

  override def combine(values: Iterable[Row], out: Collector[Row]): Unit = {
    // reset accumulators
    function.resetAccumulator(accumulators)

    val iterator = values.iterator()

    var record: Row = null
    while (iterator.hasNext) {
      record = iterator.next()
      // accumulate
      function.accumulate(accumulators, record)
    }

    // set group keys and accumulators to output
    function.setAggregationResults(accumulators, output)
    function.setForwardedFields(record, output)

    out.collect(output)
  }

  override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = {
    combine(values, out)
  }
}
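  • DataSetPreAggFunction's combine method calls function.accumulate(accumulators, record), where accumulators is a Row whose fields hold the accumulator instances (here a WeightedAvgAccum) and record is the input Row. function is an instance of a generated class that extends GeneratedAggregations; its source code is carried in genAggregations, which is produced by AggregateUtil.createDataSetAggregateFunctions, and the generated accumulate in turn calls WeightedAvg's accumulate method.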
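The combine loop above can be sketched in plain Java for a single group. This is a hypothetical illustration, not Flink code: Object[] stands in for org.apache.flink.types.Row, the input layout [groupKey, value, weight] is assumed, and the accumulator's fields are flattened into the output row for readability.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of DataSetPreAggFunction.combine for one group.
// Assumed input row layout: [groupKey, value(Long), weight(Integer)].
class PreAggSketch {

    static class WeightedAvgAccum {
        long sum = 0;
        int count = 0;
    }

    // Reused across calls, like the `accumulators` field in DataSetPreAggFunction.
    private final WeightedAvgAccum acc = new WeightedAvgAccum();

    private void resetAccumulator() {
        acc.sum = 0;
        acc.count = 0;
    }

    private void accumulate(Object[] record) {
        long value = (Long) record[1];
        int weight = (Integer) record[2];
        acc.sum += value * weight;
        acc.count += weight;
    }

    // Output layout: [groupKey, partialSum, partialCount]; the partial result is
    // the accumulator state itself, not the final average.
    Object[] combine(Iterable<Object[]> values) {
        resetAccumulator();
        Object[] record = null;
        for (Object[] r : values) {
            record = r;
            accumulate(r);
        }
        Object[] output = new Object[3];
        output[0] = record == null ? null : record[0]; // forwarded group key (cf. setForwardedFields)
        output[1] = acc.sum;                           // partial results (cf. setAggregationResults)
        output[2] = acc.count;
        return output;
    }

    public static void main(String[] args) {
        List<Object[]> group = Arrays.asList(
                new Object[]{"alice", 10L, 1},
                new Object[]{"alice", 40L, 2});
        System.out.println(Arrays.toString(new PreAggSketch().combine(group))); // [alice, 90, 3]
    }
}
```

Resetting the reused accumulator at the start of every combine call is what lets a single instance serve many groups.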

GeneratedAggregations

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/GeneratedAggregations.scala

abstract class GeneratedAggregations extends Function {

  /**
    * Setup method for [[org.apache.flink.table.functions.AggregateFunction]].
    * It can be used for initialization work. By default, this method does nothing.
    *
    * @param ctx The runtime context.
    */
  def open(ctx: RuntimeContext)

  /**
    * Sets the results of the aggregations (partial or final) to the output row.
    * Final results are computed with the aggregation function.
    * Partial results are the accumulators themselves.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param output       output results collected in a row
    */
  def setAggregationResults(accumulators: Row, output: Row)

  /**
    * Copies forwarded fields, such as grouping keys, from input row to output row.
    *
    * @param input        input values bundled in a row
    * @param output       output results collected in a row
    */
  def setForwardedFields(input: Row, output: Row)

  /**
    * Accumulates the input values to the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param input        input values bundled in a row
    */
  def accumulate(accumulators: Row, input: Row)

  /**
    * Retracts the input values from the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    * @param input        input values bundled in a row
    */
  def retract(accumulators: Row, input: Row)

  /**
    * Initializes the accumulators and save them to a accumulators row.
    *
    * @return a row of accumulators which contains the aggregated results
    */
  def createAccumulators(): Row

  /**
    * Creates an output row object with the correct arity.
    *
    * @return an output row object with the correct arity.
    */
  def createOutputRow(): Row

  /**
    * Merges two rows of accumulators into one row.
    *
    * @param a First row of accumulators
    * @param b The other row of accumulators
    * @return A row with the merged accumulators of both input rows.
    */
  def mergeAccumulatorsPair(a: Row, b: Row): Row

  /**
    * Resets all the accumulators.
    *
    * @param accumulators the accumulators (saved in a row) which contains the current
    *                     aggregated results
    */
  def resetAccumulator(accumulators: Row)

  /**
    * Cleanup for the accumulators.
    */
  def cleanup()

  /**
    * Tear-down method for [[org.apache.flink.table.functions.AggregateFunction]].
    * It can be used for clean up work. By default, this method does nothing.
    */
  def close()
}
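  • GeneratedAggregations defines Row-based methods such as accumulate(accumulators: Row, input: Row) and resetAccumulator(accumulators: Row), along with retract, mergeAccumulatorsPair, setAggregationResults and others.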
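GeneratedAggregations distinguishes partial results (the accumulators themselves) from final results (computed with the aggregation function). The sketch below shows why that split matters in a two-phase aggregation: partitions pre-aggregate independently, and the final stage merges accumulators pairwise before computing the value. TwoPhaseSketch and its methods are hypothetical; their names only loosely mirror createAccumulators, accumulate and mergeAccumulatorsPair, and Object[] stands in for Row.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical two-phase aggregation over an accumulator "row" (Object[])
// holding a single WeightedAvg-style accumulator in slot 0.
class TwoPhaseSketch {

    static class Acc {
        long sum;
        int count;
    }

    static Object[] createAccumulators() {
        return new Object[]{new Acc()};
    }

    static void accumulate(Object[] accs, long value, int weight) {
        Acc a = (Acc) accs[0];
        a.sum += value * weight;
        a.count += weight;
    }

    // Merges b into a and returns a.
    static Object[] mergeAccumulatorsPair(Object[] a, Object[] b) {
        Acc x = (Acc) a[0];
        Acc y = (Acc) b[0];
        x.sum += y.sum;
        x.count += y.count;
        return a;
    }

    // Final result: apply the getValue logic to the merged accumulator.
    static long finalResult(Object[] accs) {
        Acc a = (Acc) accs[0];
        return a.count == 0 ? 0L : a.sum / a.count;
    }

    public static void main(String[] args) {
        // Phase 1: two partitions pre-aggregate independently.
        Object[] p1 = createAccumulators();
        accumulate(p1, 10, 1);
        accumulate(p1, 40, 2);
        Object[] p2 = createAccumulators();
        accumulate(p2, 70, 1);

        // Phase 2: merge partial accumulators, then compute the final value once.
        List<Object[]> partials = Arrays.asList(p1, p2);
        Object[] merged = createAccumulators();
        for (Object[] p : partials) {
            merged = mergeAccumulatorsPair(merged, p);
        }
        System.out.println(finalResult(merged)); // 40
    }
}
```

Merging accumulators rather than results keeps the weighted average exact: averaging the two partial averages (30 and 70) would give 50 instead of 40.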

AggregateUtil

flink-table_2.11-1.7.1-sources.jar!/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala

object AggregateUtil {

  type CalcitePair[T, R] = org.apache.calcite.util.Pair[T, R]
  type JavaList[T] = java.util.List[T]

  //......

  /**
    * Create functions to compute a [[org.apache.flink.table.plan.nodes.dataset.DataSetAggregate]].
    * If all aggregation functions support pre-aggregation, a pre-aggregation function and the
    * respective output type are generated as well.
    */
  private[flink] def createDataSetAggregateFunctions(
      generator: AggregationCodeGenerator,
      namedAggregates: Seq[CalcitePair[AggregateCall, String]],
      inputType: RelDataType,
      inputFieldTypeInfo: Seq[TypeInformation[_]],
      outputType: RelDataType,
      groupings: Array[Int],
      tableConfig: TableConfig): (
        Option[DataSetPreAggFunction],
        Option[TypeInformation[Row]],
        Either[DataSetAggFunction, DataSetFinalAggFunction]) = {

    val needRetract = false
    val (aggInFields, aggregates, isDistinctAggs, accTypes, _) = transformToAggregateFunctions(
      namedAggregates.map(_.getKey),
      inputType,
      needRetract,
      tableConfig)

    val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
      namedAggregates,
      groupings,
      inputType,
      outputType
    )

    val aggOutFields = aggOutMapping.map(_._1)

    if (doAllSupportPartialMerge(aggregates)) {

      // compute preaggregation type
      val preAggFieldTypes = gkeyOutMapping.map(_._2)
        .map(inputType.getFieldList.get(_).getType)
        .map(FlinkTypeFactory.toTypeInfo) ++ accTypes
      val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*)

      val genPreAggFunction = generator.generateAggregations(
        "DataSetAggregatePrepareMapHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggregates.indices.map(_ + groupings.length).toArray,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = true,
        groupings,
        None,
        groupings.length + aggregates.length,
        needRetract,
        needMerge = false,
        needReset = true,
        None
      )

      // compute mapping of forwarded grouping keys
      val gkeyMapping: Array[Int] = if (gkeyOutMapping.nonEmpty) {
        val gkeyOutFields = gkeyOutMapping.map(_._1)
        val mapping = Array.fill[Int](gkeyOutFields.max + 1)(-1)
        gkeyOutFields.zipWithIndex.foreach(m => mapping(m._1) = m._2)
        mapping
      } else {
        new Array[Int](0)
      }

      val genFinalAggFunction = generator.generateAggregations(
        "DataSetAggregateFinalHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggOutFields,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = false,
        gkeyMapping,
        Some(aggregates.indices.map(_ + groupings.length).toArray),
        outputType.getFieldCount,
        needRetract,
        needMerge = true,
        needReset = true,
        None
      )

      (
        Some(new DataSetPreAggFunction(genPreAggFunction)),
        Some(preAggRowType),
        Right(new DataSetFinalAggFunction(genFinalAggFunction))
      )
    }
    else {
      val genFunction = generator.generateAggregations(
        "DataSetAggregateHelper",
        inputFieldTypeInfo,
        aggregates,
        aggInFields,
        aggOutFields,
        isDistinctAggs,
        isStateBackedDataViews = false,
        partialResults = false,
        groupings,
        None,
        outputType.getFieldCount,
        needRetract,
        needMerge = false,
        needReset = true,
        None
      )

      (
        None,
        None,
        Left(new DataSetAggFunction(genFunction))
      )
    }

  }

  //......
}
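  • AggregateUtil's createDataSetAggregateFunctions mainly generates GeneratedAggregationsFunction instances and uses them to create either a DataSetPreAggFunction (paired with a DataSetFinalAggFunction) when all aggregates support partial merging, or a single DataSetAggFunction otherwise. The code is generated dynamically because the parameter list of a user-defined method such as accumulate varies from function to function, whereas Flink's runtime always calls the fixed accumulate(accumulators: Row, input: Row) declared in GeneratedAggregations. The generated code therefore acts as an adapter: inside accumulate(accumulators: Row, input: Row) it converts the Row into the arguments the user-defined accumulate expects and then calls it.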
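The adapter role of the generated code can be illustrated with a hand-written equivalent. In the sketch below the runtime only knows the fixed, Row-style interface RowAggregations, while WeightedAvgAdapter unpacks the input fields and calls the user's accumulate(Accum, long, int). All names are hypothetical, Object[] stands in for Row, and the input layout [user, points, level] is an assumption; Flink's real generated code differs in detail.

```java
// Hypothetical fixed interface the runtime calls, analogous to
// GeneratedAggregations.accumulate(accumulators: Row, input: Row).
interface RowAggregations {
    void accumulate(Object[] accumulators, Object[] input);
}

// User-defined function with its own accumulate signature (mirrors WeightedAvg).
class UserWeightedAvg {
    static class Accum {
        long sum;
        int count;
    }

    Accum createAccumulator() { return new Accum(); }

    Long getValue(Accum acc) { return acc.count == 0 ? 0L : acc.sum / acc.count; }

    public void accumulate(Accum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }
}

// Hand-written stand-in for a generated adapter class.
// Assumed input layout: [user, points, level]; the aggregate reads fields 1 and 2.
class WeightedAvgAdapter implements RowAggregations {
    private final UserWeightedAvg udaf = new UserWeightedAvg();
    private final int[] argFields = {1, 2};

    @Override
    public void accumulate(Object[] accumulators, Object[] input) {
        // Convert the generic row into the user method's typed arguments.
        UserWeightedAvg.Accum acc = (UserWeightedAvg.Accum) accumulators[0];
        long points = (Long) input[argFields[0]];
        int level = (Integer) input[argFields[1]];
        udaf.accumulate(acc, points, level);
    }

    Object[] createAccumulators() { return new Object[]{udaf.createAccumulator()}; }

    Long result(Object[] accumulators) {
        return udaf.getValue((UserWeightedAvg.Accum) accumulators[0]);
    }

    public static void main(String[] args) {
        WeightedAvgAdapter adapter = new WeightedAvgAdapter();
        RowAggregations generic = adapter; // the runtime only sees the Row-style interface
        Object[] accs = adapter.createAccumulators();
        generic.accumulate(accs, new Object[]{"alice", 10L, 1});
        generic.accumulate(accs, new Object[]{"alice", 40L, 2});
        System.out.println(adapter.result(accs)); // 30
    }
}
```

Generating one such adapter per query is what lets the runtime keep a single calling convention while each UDAF keeps a natural, strongly typed accumulate signature.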

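Summary

  • AggregateFunction extends UserDefinedFunction and has two type parameters: T for the aggregation result type and ACC for the accumulator type. It defines createAccumulator, getValue, requiresOver, getResultType and getAccumulatorType, of which subclasses must implement createAccumulator and getValue. The accumulate method is not declared on AggregateFunction, but every subclass must define and implement it; it takes the accumulator followed by the input parameters and returns void. retract, merge and resetAccumulator are optional and are implemented as needed: retract is required for bounded OVER aggregations on DataStream (it takes the accumulator followed by the input parameters and returns void); merge is required for session-window grouping aggregations on DataStream and for grouping aggregations on DataSet (it takes an ACC and a java.lang.Iterable<ACC> and returns void); resetAccumulator is required for grouping aggregations on DataSet (it takes an ACC and returns void).
  • DataSetPreAggFunction's combine method calls function.accumulate(accumulators, record), where accumulators is a Row holding the accumulator instances (here a WeightedAvgAccum) and record is the input Row. function is an instance of a generated class extending GeneratedAggregations whose code is carried in genAggregations; genAggregations is produced by AggregateUtil.createDataSetAggregateFunctions, and the generated code calls WeightedAvg's accumulate method. GeneratedAggregations defines methods such as accumulate(accumulators: Row, input: Row) and resetAccumulator(accumulators: Row).
  • AggregateUtil's createDataSetAggregateFunctions generates GeneratedAggregationsFunction instances and then creates a DataSetPreAggFunction or a DataSetAggFunction. Code is generated dynamically because the parameters of user-defined methods such as accumulate vary, while Flink's runtime calls the fixed accumulate(accumulators: Row, input: Row) declared in GeneratedAggregations; the generated code adapts between the two by converting the Row into the arguments of the user-defined accumulate and then calling it.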

codecraft