
This article takes a look at the select operation of Flink's Table API.

Table.select

flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/api/table.scala

class Table(
    private[flink] val tableEnv: TableEnvironment,
    private[flink] val logicalPlan: LogicalNode) {

  //......

  def select(fields: String): Table = {
    val fieldExprs = ExpressionParser.parseExpressionList(fields)
    //get the correct expression for AggFunctionCall
    val withResolvedAggFunctionCall = fieldExprs.map(replaceAggFunctionCall(_, tableEnv))
    select(withResolvedAggFunctionCall: _*)
  }

  def replaceAggFunctionCall(field: Expression, tableEnv: TableEnvironment): Expression = {
    field match {
      case l: LeafExpression => l

      case u: UnaryExpression =>
        val c = replaceAggFunctionCall(u.child, tableEnv)
        u.makeCopy(Array(c))

      case b: BinaryExpression =>
        val l = replaceAggFunctionCall(b.left, tableEnv)
        val r = replaceAggFunctionCall(b.right, tableEnv)
        b.makeCopy(Array(l, r))
      // Functions calls
      case c @ Call(name, args) =>
        val function = tableEnv.getFunctionCatalog.lookupFunction(name, args)
        function match {
          case a: AggFunctionCall => a
          case a: Aggregation => a
          case p: AbstractWindowProperty => p
          case _ =>
            val newArgs =
              args.map(
                (exp: Expression) =>
                  replaceAggFunctionCall(exp, tableEnv))
            c.makeCopy(Array(name, newArgs))
        }
      // Scala functions
      case sfc @ ScalarFunctionCall(clazz, args) =>
        val newArgs: Seq[Expression] =
          args.map(
            (exp: Expression) =>
              replaceAggFunctionCall(exp, tableEnv))
        sfc.makeCopy(Array(clazz, newArgs))

      // Array constructor
      case c @ ArrayConstructor(args) =>
        val newArgs =
          c.elements
            .map((exp: Expression) => replaceAggFunctionCall(exp, tableEnv))
        c.makeCopy(Array(newArgs))

      // Other expressions
      case e: Expression => e
    }
  }

  def select(fields: Expression*): Table = {
    val expandedFields = expandProjectList(fields, logicalPlan, tableEnv)
    val (aggNames, propNames) = extractAggregationsAndProperties(expandedFields, tableEnv)
    if (propNames.nonEmpty) {
      throw new ValidationException("Window properties can only be used on windowed tables.")
    }

    if (aggNames.nonEmpty) {
      val projectsOnAgg = replaceAggregationsAndProperties(
        expandedFields, tableEnv, aggNames, propNames)
      val projectFields = extractFieldReferences(expandedFields)

      new Table(tableEnv,
        Project(projectsOnAgg,
          Aggregate(Nil, aggNames.map(a => Alias(a._1, a._2)).toSeq,
            Project(projectFields, logicalPlan).validate(tableEnv)
          ).validate(tableEnv)
        ).validate(tableEnv)
      )
    } else {
      new Table(tableEnv,
        Project(expandedFields.map(UnresolvedAlias), logicalPlan).validate(tableEnv))
    }
  }

  //......
}
  • Table provides two select methods: one takes a String argument and the other takes Expression arguments
  • The String-based select first calls ExpressionParser.parseExpressionList to parse the String, then replaces UDAGG function calls via replaceAggFunctionCall, and finally delegates to the Expression-based select
  • The Expression-based select creates a new Table based on a Project; if aggregates are present, it builds an Aggregate and wraps it in a Project (a usage sketch follows this list)
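
Below is a minimal usage sketch of the two select variants (assuming a Scala StreamTableEnvironment on Flink 1.7; the orders table and its fields 'id and 'amount are made up for illustration):

import org.apache.flink.streaming.api.scala._
import org.apache.flink.table.api.TableEnvironment
import org.apache.flink.table.api.scala._

object SelectDemo {
  def main(args: Array[String]): Unit = {
    val env = StreamExecutionEnvironment.getExecutionEnvironment
    val tEnv = TableEnvironment.getTableEnvironment(env)

    // hypothetical input table with two fields
    val orders = tEnv.fromDataStream(env.fromElements((1, 10), (2, 20)), 'id, 'amount)

    // String variant: parsed by ExpressionParser.parseExpressionList, then delegated
    val t1 = orders.select("id, amount")

    // Expression variant: the Scala implicits turn the symbols into Expressions
    val t2 = orders.select('id, 'amount)

    t2.toAppendStream[(Int, Int)].print()
    env.execute("select demo")
  }
}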

Expression

flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/expressions/Expression.scala

abstract class Expression extends TreeNode[Expression] {
  /**
    * Returns the [[TypeInformation]] for evaluating this expression.
    * It is sometimes not available until the expression is valid.
    */
  private[flink] def resultType: TypeInformation[_]

  /**
    * One pass validation of the expression tree in post order.
    */
  private[flink] lazy val valid: Boolean = childrenValid && validateInput().isSuccess

  private[flink] def childrenValid: Boolean = children.forall(_.valid)

  /**
    * Check input data types, inputs number or other properties specified by this expression.
    * Return `ValidationSuccess` if it pass the check,
    * or `ValidationFailure` with supplement message explaining the error.
    * Note: we should only call this method until `childrenValid == true`
    */
  private[flink] def validateInput(): ValidationResult = ValidationSuccess

  /**
    * Convert Expression to its counterpart in Calcite, i.e. RexNode
    */
  private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode =
    throw new UnsupportedOperationException(
      s"${this.getClass.getName} cannot be transformed to RexNode"
    )

  private[flink] def checkEquals(other: Expression): Boolean = {
    if (this.getClass != other.getClass) {
      false
    } else {
      def checkEquality(elements1: Seq[Any], elements2: Seq[Any]): Boolean = {
        elements1.length == elements2.length && elements1.zip(elements2).forall {
          case (e1: Expression, e2: Expression) => e1.checkEquals(e2)
          case (t1: Seq[_], t2: Seq[_]) => checkEquality(t1, t2)
          case (i1, i2) => i1 == i2
        }
      }
      val elements1 = this.productIterator.toSeq
      val elements2 = other.productIterator.toSeq
      checkEquality(elements1, elements2)
    }
  }
}

abstract class BinaryExpression extends Expression {
  private[flink] def left: Expression
  private[flink] def right: Expression
  private[flink] def children = Seq(left, right)
}

abstract class UnaryExpression extends Expression {
  private[flink] def child: Expression
  private[flink] def children = Seq(child)
}

abstract class LeafExpression extends Expression {
  private[flink] val children = Nil
}
  • Expression extends TreeNode; it has three abstract subclasses: BinaryExpression, UnaryExpression, and LeafExpression (a simplified sketch of this pattern follows)
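
As a simplified standalone illustration of the same pattern (not Flink's actual classes): leaf/unary/binary tree nodes with the bottom-up valid check that Expression uses.

// minimal sketch: a TreeNode-style hierarchy with post-order validation
sealed trait Node {
  def children: Seq[Node]
  // per-node input check, analogous to Expression.validateInput()
  def validateInput(): Boolean = true
  // post-order validation: all children must be valid before this node is checked
  lazy val valid: Boolean = children.forall(_.valid) && validateInput()
}

case class Field(name: String) extends Node {            // like LeafExpression
  val children: Seq[Node] = Nil
}
case class Neg(child: Node) extends Node {                // like UnaryExpression
  def children: Seq[Node] = Seq(child)
}
case class Plus(left: Node, right: Node) extends Node {   // like BinaryExpression
  def children: Seq[Node] = Seq(left, right)
}

object ExpressionTreeDemo extends App {
  val expr = Plus(Field("a"), Neg(Field("b")))
  println(expr.valid)  // true: children are validated before the parent
}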

Project

flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/plan/logical/operators.scala

case class Project(
    projectList: Seq[NamedExpression],
    child: LogicalNode,
    explicitAlias: Boolean = false)
  extends UnaryNode {

  override def output: Seq[Attribute] = projectList.map(_.toAttribute)

  override def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
    val afterResolve = super.resolveExpressions(tableEnv).asInstanceOf[Project]
    val newProjectList =
      afterResolve.projectList.zipWithIndex.map { case (e, i) =>
        e match {
          case u @ UnresolvedAlias(c) => c match {
            case ne: NamedExpression => ne
            case expr if !expr.valid => u
            case c @ Cast(ne: NamedExpression, tp) => Alias(c, s"${ne.name}-$tp")
            case gcf: GetCompositeField => Alias(gcf, gcf.aliasName().getOrElse(s"_c$i"))
            case other => Alias(other, s"_c$i")
          }
          case _ =>
            throw new RuntimeException("This should never be called and probably points to a bug.")
        }
    }
    Project(newProjectList, child, explicitAlias)
  }

  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    val resolvedProject = super.validate(tableEnv).asInstanceOf[Project]
    val names: mutable.Set[String] = mutable.Set()

    def checkName(name: String): Unit = {
      if (names.contains(name)) {
        failValidation(s"Duplicate field name $name.")
      } else {
        names.add(name)
      }
    }

    resolvedProject.projectList.foreach {
      case n: Alias =>
        // explicit name
        checkName(n.name)
      case r: ResolvedFieldReference =>
        // simple field forwarding
        checkName(r.name)
      case _ => // Do nothing
    }
    resolvedProject
  }

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)

    val exprs = if (explicitAlias) {
      projectList
    } else {
      // remove AS expressions, according to Calcite they should not be in a final RexNode
      projectList.map {
        case Alias(e: Expression, _, _) => e
        case e: Expression => e
      }
    }

    relBuilder.project(
      exprs.map(_.toRexNode(relBuilder)).asJava,
      projectList.map(_.name).asJava,
      true)
  }
}
  • Project extends UnaryNode; its constructor takes three parameters: Seq[NamedExpression], LogicalNode, and explicitAlias, where explicitAlias is optional and defaults to false (see the auto-alias sketch below)
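
The interesting step in resolveExpressions is the auto-alias rule for UnresolvedAlias: named expressions keep their name, a Cast is aliased as "name-type", a GetCompositeField prefers its own alias, and everything else gets a positional name like _c0, _c1. A simplified standalone sketch of that rule (not Flink code):

object AutoAliasDemo extends App {
  sealed trait Expr
  case class FieldRef(name: String) extends Expr            // already named
  case class Plus(left: Expr, right: Expr) extends Expr     // unnamed computed column
  case class Alias(child: Expr, name: String) extends Expr  // explicit name

  // analogous to Project.resolveExpressions: generate names by position when needed
  def autoAlias(projectList: Seq[Expr]): Seq[Expr] =
    projectList.zipWithIndex.map {
      case (f: FieldRef, _) => f
      case (a: Alias, _)    => a
      case (other, i)       => Alias(other, s"_c$i")
    }

  println(autoAlias(Seq(FieldRef("a"), Plus(FieldRef("a"), FieldRef("b")))))
  // List(FieldRef(a), Alias(Plus(FieldRef(a),FieldRef(b)),_c1))
}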

Aggregate

flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/plan/logical/operators.scala

case class Aggregate(
    groupingExpressions: Seq[Expression],
    aggregateExpressions: Seq[NamedExpression],
    child: LogicalNode) extends UnaryNode {

  override def output: Seq[Attribute] = {
    (groupingExpressions ++ aggregateExpressions) map {
      case ne: NamedExpression => ne.toAttribute
      case e => Alias(e, e.toString).toAttribute
    }
  }

  override protected[logical] def construct(relBuilder: RelBuilder): RelBuilder = {
    child.construct(relBuilder)
    relBuilder.aggregate(
      relBuilder.groupKey(groupingExpressions.map(_.toRexNode(relBuilder)).asJava),
      aggregateExpressions.map {
        case Alias(agg: Aggregation, name, _) => agg.toAggCall(name)(relBuilder)
        case _ => throw new RuntimeException("This should never happen.")
      }.asJava)
  }

  override def validate(tableEnv: TableEnvironment): LogicalNode = {
    implicit val relBuilder: RelBuilder = tableEnv.getRelBuilder
    val resolvedAggregate = super.validate(tableEnv).asInstanceOf[Aggregate]
    val groupingExprs = resolvedAggregate.groupingExpressions
    val aggregateExprs = resolvedAggregate.aggregateExpressions
    aggregateExprs.foreach(validateAggregateExpression)
    groupingExprs.foreach(validateGroupingExpression)

    def validateAggregateExpression(expr: Expression): Unit = expr match {
      case distinctExpr: DistinctAgg =>
        distinctExpr.child match {
          case _: DistinctAgg => failValidation(
            "Chained distinct operators are not supported!")
          case aggExpr: Aggregation => validateAggregateExpression(aggExpr)
          case _ => failValidation(
            "Distinct operator can only be applied to aggregation expressions!")
        }
      // check aggregate function
      case aggExpr: Aggregation
        if aggExpr.getSqlAggFunction.requiresOver =>
        failValidation(s"OVER clause is necessary for window functions: [${aggExpr.getClass}].")
      // check no nested aggregation exists.
      case aggExpr: Aggregation =>
        aggExpr.children.foreach { child =>
          child.preOrderVisit {
            case agg: Aggregation =>
              failValidation(
                "It's not allowed to use an aggregate function as " +
                  "input of another aggregate function")
            case _ => // OK
          }
        }
      case a: Attribute if !groupingExprs.exists(_.checkEquals(a)) =>
        failValidation(
          s"expression '$a' is invalid because it is neither" +
            " present in group by nor an aggregate function")
      case e if groupingExprs.exists(_.checkEquals(e)) => // OK
      case e => e.children.foreach(validateAggregateExpression)
    }

    def validateGroupingExpression(expr: Expression): Unit = {
      if (!expr.resultType.isKeyType) {
        failValidation(
          s"expression $expr cannot be used as a grouping expression " +
            "because it's not a valid key type which must be hashable and comparable")
      }
    }
    resolvedAggregate
  }
}
  • Aggregate extends UnaryNode; its constructor takes three parameters: Seq[Expression], Seq[NamedExpression], and LogicalNode (see the usage sketch below)
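
Continuing the earlier usage sketch (the orders table with fields 'id and 'amount is assumed), whether select produces a plain Project or a Project over an Aggregate depends on the presence of aggregates:

// no aggregates: select builds Project(expandedFields, logicalPlan)
val projected = orders.select('id, 'amount)

// aggregates present: select builds Project(projectsOnAgg, Aggregate(Nil, aggNames, Project(projectFields, logicalPlan)))
val totals = orders.select('amount.sum as 'total, 'amount.max as 'maxAmount)

// mixing a plain field with an aggregate (without a groupBy) is rejected by Aggregate.validate,
// roughly: "expression 'id' is invalid because it is neither present in group by nor an aggregate function"
// orders.select('id, 'amount.sum)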

LogicalNode

flink-table_2.11-1.7.0-sources.jar!/org/apache/flink/table/plan/logical/LogicalNode.scala

abstract class LogicalNode extends TreeNode[LogicalNode] {
  def output: Seq[Attribute]

  def resolveExpressions(tableEnv: TableEnvironment): LogicalNode = {
    // resolve references and function calls
    val exprResolved = expressionPostOrderTransform {
      case u @ UnresolvedFieldReference(name) =>
        // try resolve a field
        resolveReference(tableEnv, name).getOrElse(u)
      case c @ Call(name, children) if c.childrenValid =>
        tableEnv.getFunctionCatalog.lookupFunction(name, children)
    }

    exprResolved.expressionPostOrderTransform {
      case ips: InputTypeSpec if ips.childrenValid =>
        var changed: Boolean = false
        val newChildren = ips.expectedTypes.zip(ips.children).map { case (tpe, child) =>
          val childType = child.resultType
          if (childType != tpe && TypeCoercion.canSafelyCast(childType, tpe)) {
            changed = true
            Cast(child, tpe)
          } else {
            child
          }
        }.toArray[AnyRef]
        if (changed) ips.makeCopy(newChildren) else ips
    }
  }

  final def toRelNode(relBuilder: RelBuilder): RelNode = construct(relBuilder).build()

  protected[logical] def construct(relBuilder: RelBuilder): RelBuilder

  def validate(tableEnv: TableEnvironment): LogicalNode = {
    val resolvedNode = resolveExpressions(tableEnv)
    resolvedNode.expressionPostOrderTransform {
      case a: Attribute if !a.valid =>
        val from = children.flatMap(_.output).map(_.name).mkString(", ")
        // give helpful error message for null literals
        if (a.name == "null") {
          failValidation(s"Cannot resolve field [${a.name}] given input [$from]. If you want to " +
            s"express a null literal, use 'Null(TYPE)' for typed null expressions. " +
            s"For example: Null(INT)")
        } else {
          failValidation(s"Cannot resolve field [${a.name}] given input [$from].")
        }

      case e: Expression if e.validateInput().isFailure =>
        failValidation(s"Expression $e failed on input check: " +
          s"${e.validateInput().asInstanceOf[ValidationFailure].message}")
    }
  }

  /**
    * Resolves the given strings to a [[NamedExpression]] using the input from all child
    * nodes of this LogicalPlan.
    */
  def resolveReference(tableEnv: TableEnvironment, name: String): Option[NamedExpression] = {
    // try to resolve a field
    val childrenOutput = children.flatMap(_.output)
    val fieldCandidates = childrenOutput.filter(_.name.equalsIgnoreCase(name))
    if (fieldCandidates.length > 1) {
      failValidation(s"Reference $name is ambiguous.")
    } else if (fieldCandidates.nonEmpty) {
      return Some(fieldCandidates.head.withName(name))
    }

    // try to resolve a table
    tableEnv.scanInternal(Array(name)) match {
      case Some(table) => Some(TableReference(name, table))
      case None => None
    }
  }

  /**
    * Runs [[postOrderTransform]] with `rule` on all expressions present in this logical node.
    *
    * @param rule the rule to be applied to every expression in this logical node.
    */
  def expressionPostOrderTransform(rule: PartialFunction[Expression, Expression]): LogicalNode = {
    var changed = false

    def expressionPostOrderTransform(e: Expression): Expression = {
      val newExpr = e.postOrderTransform(rule)
      if (newExpr.fastEquals(e)) {
        e
      } else {
        changed = true
        newExpr
      }
    }

    val newArgs = productIterator.map {
      case e: Expression => expressionPostOrderTransform(e)
      case Some(e: Expression) => Some(expressionPostOrderTransform(e))
      case seq: Traversable[_] => seq.map {
        case e: Expression => expressionPostOrderTransform(e)
        case other => other
      }
      case r: Resolvable[_] => r.resolveExpressions(e => expressionPostOrderTransform(e))
      case other: AnyRef => other
    }.toArray

    if (changed) makeCopy(newArgs) else this
  }

  protected def failValidation(msg: String): Nothing = {
    throw new ValidationException(msg)
  }
}

abstract class LeafNode extends LogicalNode {
  override def children: Seq[LogicalNode] = Nil
}

abstract class UnaryNode extends LogicalNode {
  def child: LogicalNode

  override def children: Seq[LogicalNode] = child :: Nil
}

abstract class BinaryNode extends LogicalNode {
  def left: LogicalNode
  def right: LogicalNode

  override def children: Seq[LogicalNode] = left :: right :: Nil
}
  • Like Expression, LogicalNode also extends TreeNode; it has three abstract subclasses: BinaryNode, UnaryNode, and LeafNode (a simplified post-order transform sketch follows)
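
Field and function resolution relies on expressionPostOrderTransform: children are rewritten before their parent, so a node always sees already-resolved children. A minimal standalone sketch of that post-order idea (not Flink code):

object PostOrderTransformDemo extends App {
  sealed trait Expr
  case class Unresolved(name: String) extends Expr
  case class Resolved(name: String, tpe: String) extends Expr
  case class Plus(left: Expr, right: Expr) extends Expr

  // rewrite children first, then apply the rule to the (possibly rebuilt) parent
  def postOrderTransform(e: Expr)(rule: PartialFunction[Expr, Expr]): Expr = {
    val withNewChildren = e match {
      case Plus(l, r) => Plus(postOrderTransform(l)(rule), postOrderTransform(r)(rule))
      case other      => other
    }
    rule.applyOrElse(withNewChildren, identity[Expr])
  }

  val schema = Map("a" -> "INT", "b" -> "INT")
  val resolved = postOrderTransform(Plus(Unresolved("a"), Unresolved("b"))) {
    case Unresolved(n) if schema.contains(n) => Resolved(n, schema(n))
  }
  println(resolved)  // Plus(Resolved(a,INT),Resolved(b,INT))
}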

Summary

  • Table provides two select methods, one taking a String argument and one taking Expression arguments; the String-based select first calls ExpressionParser.parseExpressionList to parse the String, then replaces UDAGG function calls via replaceAggFunctionCall, and finally delegates to the Expression-based select
  • The Expression-based select creates a new Table based on a Project; if aggregates are present, it builds an Aggregate and then wraps it in a Project
  • Project and Aggregate are both case classes that extend UnaryNode, a subclass of LogicalNode; like Expression, LogicalNode also extends TreeNode; Expression has three abstract subclasses (BinaryExpression, UnaryExpression, LeafExpression), and LogicalNode likewise has three (BinaryNode, UnaryNode, LeafNode)
