作者:LogM
本文原载于 https://segmentfault.com/u/logm/articles ,不允许转载~
1. 前言
Tensorflow Lite 是 Tensorflow 移动端的版本。
有关于 Tensorflow 怎么添加自定义 op,网上有很多博客都讲到了,我就不介绍了。而 Tensorflow Lite 因为相对小众一些,所以网上关于添加自定义 op 的教程很少。
刚好最近因为项目需要,我在 Tensorflow Lite 中添加了几个自定义 op。我把我的思考过程以及修改步骤记录下来,方便有相同需求的同学参考。
我花了大篇幅记录思考过程和源码阅读过程,是希望给其他小伙伴一些启发,以后遇到类似的深度学习框架魔改的问题,可以不依赖网上教程。
不关心思考过程和源码阅读的小伙伴,可以直接跳到文章的最后,我把修改的步骤做了总结。
2. 源码来源
我使用源码是 Tensorflow v1.13.2
Tensorflow Lite 位于 tensorflow/lite
目录下。
3. 官方教程
官网也有关于 Tensorflow Lite 怎么添加自定义 op 的教程,详见官方地址。
官方教程把"怎么写自定义 op 的代码"讲得很清楚,遗憾的是没有详细说明怎么把这些新写的代码放入到工程中编译。
4. 进入正题
第1步,找到目标文件夹位置
首先我们要找到源码中放置自定义 op 的文件夹位置。有多种寻找的方式:
- tensorflow 源码的目录结构非常清楚,有过类似框架阅读经验的同学应该马上能猜出位置;
- 官方教程告诉我们,自定义 op 的代码要实现
Prepare
和Eval
这两个函数,那么我们使用 grep 命令查找有哪些代码文件中带有这两个函数。
最终,我们找到的位置是 tensorflow/lite/kernels
。
找到目标文件夹位置以后,把新增代码放入该文件夹就可以了吗?显然,没有这么简单。有几个方面需要考虑:
- 代码逻辑层面,新增代码的逻辑怎么与源码的逻辑连接起来;
- 编译层面,新增代码怎么参与编译。
第2步,新增代码的逻辑怎么与源码的逻辑连接起来?
有过类似深度学习框架阅读经验的同学应该很快能想到,对于"添加自定义op"这个操作,就是个"op注册"的过程,所以马上想到去寻找带"register"字样的文件。
而没有深度学习框架阅读经验的同学也不用慌,官方教程告诉我们,自定义op在使用前需要调用 AddCustom
函数。那么很明显,这个函数就起到了将自定义op的逻辑与源码逻辑连接起来的任务。所以使用 grep 命令查找有哪些代码文件中带有这个函数。
两种方式殊途同归,找到关键文件 tensorflow/lite/kernels/register.cc
。
// 文件:tensorflow/lite/kernels/register.cc
// 行数:22
namespace custom {
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
TfLiteRegistration* Register_RELU_1();
}
// 文件:tensorflow/lite/kernels/register.cc
// 行数:278
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
嘿嘿嘿,我们发现官方源码中也放了5个自定义op,而且官方偷懒把自定义op与内置op的注册过程写在了一起,那么我们来看看官方是怎么写自定义op的吧,比如 Relu1
这个。
// 文件:tensorflow/lite/kernels/relu1.cc
#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace custom {
namespace relu1 {
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
const TfLiteTensor* input = GetInput(context, node, 0);
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TfLiteTensor* output = GetOutput(context, node, 0);
output->type = input->type;
return context->ResizeTensor(context, output,
TfLiteIntArrayCopy(input->dims));
}
// This is derived from lite/kernels/activations.cc.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input = GetInput(context, node, 0);
TfLiteTensor* output = GetOutput(context, node, 0);
const int elements = NumElements(input);
const float* in = input->data.f;
const float* in_end = in + elements;
float* out = output->data.f;
for (; in < in_end; ++in, ++out) {
*out = std::min(std::max(0.f, *in), 1.f);
}
return kTfLiteOk;
}
} // namespace relu1
TfLiteRegistration* Register_RELU_1() {
static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
relu1::Prepare, relu1::Eval};
return &r;
}
} // namespace custom
} // namespace ops
} // namespace tflite
可以看到,与官方给出的教程一样,关键点是实现 Prepare
和 Eval
这两个函数。我们自己在自定义op的代码时,可以把这个文件当做参考模板。
第3步,新增代码怎么参与编译?
这块需要一些 C++ 大工程开发的知识,Tensorflow 是用 Bazel 作工程编译的,所以关键点在目标文件夹下的 BUILD
文件。
而 BUILD
文件里面这么多的 library,我们的新代码应该编译到哪个 library 中呢?还记得 官方留的自定义op "Relu1" 吗?我们来看看 "Relu1" 是编译到哪个 library。
// 文件:tensorflow/lite/kernels/BUILD
// 行数:278
cc_library(
name = "builtin_op_kernels",
srcs = [
... // 这里有很多其他的源文件
"mfcc.cc",
"relu1.cc",
... // 把新写的代码文件加到这边就可以了
],
hdrs = [
],
copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
visibility = ["//visibility:private"],
deps = [
":activation_functor",
":eigen_support",
":kernel_util",
":lstm_eval",
":op_macros",
":padding",
"//tensorflow/lite:framework",
"//tensorflow/lite:string_util",
"//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/kernels:gemm_support",
"//tensorflow/lite/kernels/internal:audio_utils",
"//tensorflow/lite/kernels/internal:kernel_utils",
"//tensorflow/lite/kernels/internal:optimized",
"//tensorflow/lite/kernels/internal:optimized_base",
"//tensorflow/lite/kernels/internal:quantization_util",
"//tensorflow/lite/kernels/internal:reference_base",
"//tensorflow/lite/kernels/internal:tensor",
"//tensorflow/lite/kernels/internal:tensor_utils",
"@farmhash_archive//:farmhash",
"@flatbuffers",
],
)
嘿嘿嘿,官方可真会偷懒,自定义 op 和内置 op 一起编译到 builtin_op_kernels
库。所以,我们只要把新的代码文件添加到 srcs=[]
里,新的代码就能参与到编译过程中了。
5. 总结
Tensorflow Lite v1.13.2 中,官方偷了个懒,自定义 op 与内置 op 写在同一个位置,且都是编译到 builtin_op_kernels
库。
Tensorflow Lite 的自定义 op 添加方式如下:
**粗体** _斜体_ [链接](http://example.com) `代码` - 列表 > 引用
。你还可以使用@
来通知其他用户。