标签:roo code ring tip uil ons 通过 flow ice
参考:https://tensorflow.juejin.im/extend/adding_an_op.html
https://zhuanlan.zhihu.com/p/34168765
为了加入一个定制操作,你需要:
tf.test.compute_gradient_error
1. 定义接口:
#include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" using namespace tensorflow; REGISTER_OP("ZeroOut") .Input("to_zero: int32") .Output("zeroed: int32") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); return Status::OK(); });
关于命名的备注:操作名称必须首字母大写,而且不能和库中已经注册的其它操作重名。
2. 实现操作的内核
定义接口后,接下来就需要为此操作提供一个或多个内核实现了。
为了实现这些内核,创建一个继承自 OpKernel
的类,并重载 Compute
方法。
Compute
方法有一个类型为 OpKernelContext*
的参数 context
,从中可以访问输入和输出张量等有用的信息。
将你的内核加到上面创建的文件中。这个内核的代码形如:
#include "tensorflow/core/framework/op_kernel.h" using namespace tensorflow; class ZeroOutOp : public OpKernel { public: explicit ZeroOutOp(OpKernelConstruction* context) : OpKernel(context) {} void Compute(OpKernelContext* context) override { // 得到输入张量 const Tensor& input_tensor = context->input(0); auto input = input_tensor.flat<int32>(); // 创建输出张量 Tensor* output_tensor = NULL; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); auto output_flat = output_tensor->flat<int32>(); // 除第一个元素外,输出张量的其它所有元素都设置为 0 const int N = input.size(); for (int i = 1; i < N; i++) { output_flat(i) = 0; } // 如果可能的话,保留第一个输入值 if (N > 0) output_flat(0) = input(0); } };
给 ZeroOut
操作加上约束条件:
REGISTER_KERNEL_BUILDER(Name("ZeroOut").Device(DEVICE_CPU), ZeroOutOp);
这里注册的操作名是ZeroOut,通过上面的语句和ZeroOutOp对应吧,
下面对前面的示例做个总结,一个操作注册可以指定多个输入输出:
REGISTER_OP("MultipleInsAndOuts") .Input("y: int32") .Input("z: float") .Output("a: string") .Output("b: int32");
标签:roo code ring tip uil ons 通过 flow ice
原文地址:https://www.cnblogs.com/inshallah/p/12084043.html