1 Op与Kernel的注册

继续追踪执行流程会发现,ReluFunctor在构造UserOpExpr时会用到UserOpRegistryMgr管理的Op与Kernel。Op表示算子的描述信息,Kernel实现在不同设备上的计算。

注册信息保存在私有的map变量中。UserOpRegistryMgr的头文件中定义了3个宏,分别用于注册op、grad_op、kernel。

1.1 ReluOp的注册

REGISTER_USER_OP负责UserOp的注册。通过检索代码可以找到这个宏的使用场景。ReluOp相关的源代码在这3个文件中:

  • class定义: build/oneflow/core/framework/op_generated.h
  • 注册op、op的部分实现: build/oneflow/core/framework/op_generated.cpp
  • 主要实现: oneflow/oneflow/user/ops/relu_op.cpp

REGISTER_USER_OP宏在op_generated.cpp中展开后代码如下:

static UserOpRegisterTrigger<OpRegistry> g_register_trigger715 =
  ::oneflow::user_op::UserOpRegistryMgr::Get()
  .CheckAndGetOpRegistry("relu")
  .Input("x")
  .Output("y")
  .SetGetSbpFn(&ReluOp::GetSbp)
  .SetLogicalTensorDescInferFn(&ReluOp::InferLogicalTensorDesc)
  .SetPhysicalTensorDescInferFn(&ReluOp::InferPhysicalTensorDesc)
  .SetDataTypeInferFn(&ReluOp::InferDataType);

调用流程如下:
op与kernel注册

CheckAndGetOpRegistry会创建一个OpRegistry对象,这个类和UserOpRegisterTrigger类一样,只是为构造OpRegistryResult用的中间类型。OpRegistry会暂存中间结果并在Finish中设置一些默认推导逻辑。UserOpRegisterTrigger的构造函数会调用注册逻辑。静态变量就是为了触发构造函数从而调用注册逻辑,将构造好的OpRegistryResult保存到UserOpRegistryMgr(key是op_type,如relu)。

ReluOp表示一个具体的op_type,负责为OpRegistryResult提供Op特有的方法。

OpRegistryResult把不同的Op抽象为一个通用的结构(便于统一注册管理),主要包含描述信息,保存了op的输入输出描述,以及数据类型、sbp等的推导逻辑函数。对于relu来说,主要是记录了几个推导函数要调用ReluOp的静态方法;op_def主要包含input/output的名字。

1.2 ReluKernel的注册

ReluKernel在relu_kernel.cpp中注册,过程和Op的注册类似。REGISTER_USER_KERNEL宏产开后如下所示:

static UserOpRegisterTrigger<OpKernelRegistry> g_register_trigger0 =
  UserOpRegistryMgr::Get().
    CheckAndGetOpKernelRegistry("relu").
    // 通过模版参数指定kernel为ReluKernel
    SetCreateFn<ReluKernel>().
    // 参数不是bool表达式,应该是一个高阶表达式对象
    SetIsMatchedHob(ReluPrimitiveExists() == true);

注意SetCreateFn只是把一个如下的lambda表达式赋值给result_.create_fn,这个字段很重要,后续执行就是通过它获取kernel。

[] () -> const OpKernel* { return NewOpKernel<T>(); }

对于relu来说,NewOpKernel就是new一个ReluKernel对象并返回指针。

最终注册的结果,会把OpKernelRegistryResult保存到UserOpRegistryMgr(key是op_type,如relu)。

1.3 Op和Kernel注册相关的类关系图

op与kernel注册相关的类

2 UserOpExpr的构造

上一篇提到,functional_api.yaml.cpp中的functional::Relu函数通过find("Relu")获取预先注册的PackedFunctor<impl::ReluFunctor>,调用其call方法会执行impl::ReluFunctor

ReluFunctor的核心代码如下:

class ReluFunctor {
 public:
  ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build()); }
  Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const {
    // 忽略inplace相关逻辑
    return OpInterpUtil::Dispatch<Tensor>(*op_, {x});
  }
 private:
  std::shared_ptr<OpExpr> op_;
};

ReluFunctor的构造函数中,主要是构造UserOpExprUserOpExpr可以看作user op type的子概念,relu只有一个,scalar_add等有多个UserOpExpr

构造函数内的调用顺序如下:
UserOpExpr的构造

OpBuilderInput/Output调用主要是构造UserOpConf对象,Build函数内会修改UserOpConf对象,比如根据OpRegistryResult::op_def补充默认值到attr。之后构造UserOpExpr对象,UserOpConf对象被保存到UserOpExpr的父类BuiltinOpExprImpl<UserOpConf>op_proto_字段,对于relu来说,op_proto_主要保存input, output等信息。UserOpExpr初始化时会从OpRegistryResult拷贝函数变量。

3 Functor的执行

ReluFunctor执行的核心逻辑是调用OpInterpUtil::Dispatch。调运顺序如下:
ReluFunctor的执行

整个链路很长,本篇笔记只对前半部分的重点内容做一些说明。

3.1 根据环境和输入选择解释器

Dispatch调用的GetInterpreter返回的是一个AutogradInterpreter对象,这个类是在其内含的OpExprInterpreter成员变量基础之上增加了autograd的功能。GetInterpreter内实际构造的是以下3种Interpreter,在Build函数返回时转为AutogradInterpreter

  • LazyInterpreter: 应该用于lazy执行模式
  • EagerMirroredInterpreter: 貌似是单机(或数据并行)的动态图执行模式
  • EagerConsistentInterpreter: 分布式的动态图执行模式

各个Interpreter的关系如下:
Interpreter类之间的关系

GetInterpreter的作用是根据输入和环境等信息,选择一个合适的解释器。

接着在Dispatch中调用解释器的AutogradInterpreter::Apply方法,在这个方法内调用internal_->Apply(...),也就是上述3个解释器的Apply方法。

EagerConsistentInterpreter为例。这个类并没有定义Apply方法,实际调用的是父类方法EagerInterpreter::Apply。这个方法中调用一系列的APPLY_IF宏,就是用dynamic_cast判断op_expr的类型,类型合适才执行,对于relu会调用UserOpExpr版的ApplyImpl方法。

3.2 装饰器

EagerConsistentInterpreter::ApplyImpl的相关代码如下所示:

// Decorator的模版参数可以通过func的类型推断
// oneflow/core/common/decorator.h
template<template<typename...> class Decorator>
struct WithDecorator final {
  template<typename T, typename = void> struct Decorate;
  template<typename T, typename... Args>
  struct Decorate<T (*)(Args...)> final {
    template<T (*func)(Args...)>
    static T Call(Args... args) {
      return Decorator<T, Args...>::template Call<func>(args...);
    }
  };
};

// oneflow/core/framework/tensor_consistent_id.h
template<typename Arg0, typename Arg1, typename... Args>
struct NonRecursiveInitConsistentId<Maybe<void>, Arg0, Arg1, TensorTuple*, Args...> {
  template<Maybe<void> (*func)(Arg0, Arg1, TensorTuple*, Args...)>
  static Maybe<void> Call(Arg0 arg0, Arg1 arg1, TensorTuple* outputs, Args... args) {
    // ...
    Maybe<void> ret = func(arg0, arg1, outputs, args...);
    // ...
    return ret;
  }
};

// 宏展开
// 去掉模版参数后就是 &WithDecorator::Decorate::Call
auto* InterpretThenInitConsistentId =
(&WithDecorator<NonRecursiveInitConsistentId>::Decorate<__decltype(&Interpret)>::Call<&Interpret>);

Maybe<void> EagerConsistentInterpreter::ApplyImpl(const UserOpExpr& op_expr,
                                                  const TensorTuple& inputs, TensorTuple* outputs,
                                                  const OpExprInterpContext& ctx) const {
  return InterpretThenInitConsistentId(op_expr, inputs, outputs, ctx);
}

InterpretThenInitConsistentId是匿名命名空间中通过宏定义的一个函数指针,如果将其中的模版参数都去掉,简化后就是函数指针&WithDecorator::Decorate::Call。也就是说,ApplyImpl函数直接把任务转发给WithDecorator::Decorate::Call,再转发给NonRecursiveInitConsistentId::Call。函数Interpret的类型决定了其余模版参数的推断,它就是模版定义中的func,在NonRecursiveInitConsistentId::Call中实际调用的就是InterpretNonRecursiveInitConsistentIdInterprete外面套了一层,主要做传输token等处理。

这是典型的Decorator模式,巧妙地通过精心设计的模版解决众多场景的逻辑处理。

WithDecorator的作用主要是将具体的Decorator与调用环境解耦,可以支持多个Decorator的组合。例如GetBoxingOutput就组合了2个装饰器。

3.3 Interpret的执行

Interpret核心代码放如下(稍微调整以便于演示):

// oneflow/core/framework/op_interpreter/eager_consistent_op_interpreter.cpp
Maybe<void> Interpret(const UserOpExpr& user_op_expr, const TensorTuple& inputs,
                           const Symbol<Device>& default_device, TensorTuple* outputs,
                           const OpExprInterpContext& ctx) {
  // ...
  std::shared_ptr<const ConsistentTensorInferResult> result =
    JUST(user_op_expr.mut_consistent_tensor_infer_cache()->GetOrInfer(*infer_args));
  // ...
  const auto& kernel = JUST(user_op_expr.MutKernel4Stream(result->stream()));
  // ...
  JUST(PhysicalRun([&](InstructionsBuilder* builder) -> Maybe<void> {
    return builder->LocalCallOpKernel(kernel, input_eager_blob_objects, output_eager_blob_objects,
                                      ctx, stream);
  }));
}

MutKernel4Stream获取kernel的调用顺序如下:
MutKernel4Stream获取kernel的调用顺序

Interpret中的result对象是TensorInferResult,而不是outputs。它的获取过程比较复杂,我们只关注它的stream字段。relu的inputs不是空的,result.stream会被设置为inputs[0]的device,默认就是CPU。这个stream变量的类型是oneflow::Stream。(和后面的指令Stream是不同类型)

Interpret中出现的user_op_expr.MutKernel4Stream函数调用有几点需要说明:

  • 这个函数是UserOpExpr的成员方法,负责修改stream2kernel_成员变量。这个map为op维护从设备到StatefulLocalOpKernel的映射。
  • MutKernel4Stream函数中对UserOpExpr的map成员的修改不存在数据竞争问题user_op_expr是保存在ReluFunctor中,而ReluFunctor保存在functional_api.yaml.cpp中的functional::Relu函数的静态op__变量中(作为lambda捕获的一部分),op__是thread_local的,所以不存在数据竞争问题。
  • StatefulLocalOpKernel本身并没有实现kernel计算逻辑,它只是保存kernel的一些信息。后面会看到,在生成虚拟机指令时会调用它的ChooseOpKernel方法设置实际的kernel。
  • BuildOpConf基于UserOpExpr的proto为kernel生成配置,增加的字段主要是device_tag,这也是来自result.stream,对relu来说也就是inputs[0]的stream。

Interpret最终会构造一个lambda表达式并传给PhysicalRun,构造指令并提交虚拟机调度执行。

参考资料


郑建华
1 声望4 粉丝