标签:shared 有用 类构造 enqueue queue public mpi blank 访问
参考: http://www.tensorfly.cn/tfdoc/how_tos/adding_an_op.html
添加新的OP需要3步(下述所有代码在here):
1. 定义 Op 的接口
// 1. 定义 Op 的接口 // REGISTER_OP()向 TensorFlow 系统注册来定义 Op 的接口,该OP就是HorovodAllreduceOp. // 在注册时, 指定 Op 的名称: REGISTER_OP("HorovodAllreduce") // 输入(类型和名称): Input("tensor: T") // 输出(类型和名称): Output("sum: T") // 和所需要任何 属性的文档说明Doc(R"doc(...)doc"); // // 该 Op 接受一个 T 类型 tensor 作为输入, T 类型可以是{int32, int64, float32, float64} // 输出一个 T 类型 tensor sum,sum是在所有的MPI进程中求和 REGISTER_OP("HorovodAllreduce") .Attr("T: {int32, int64, float32, float64}") .Input("tensor: T") .Output("sum: T") .SetShapeFn([](shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); return Status::OK(); }) .Doc(R"doc( Perform an MPI Allreduce on a tensor. All other processes that do a reduction on a tensor with the same name must have the same dimension for that tensor. Tensors are reduced with other tensors that have the same node name for the allreduce. Arguments tensor: A tensor to reduce. Output sum: A tensor with the same shape as `tensor`, summed across all MPI processes. )doc");
2. 为 Op 实现 kernel
// 2. 为 Op 实现 kernel。 // 在定义接口之后, 每一个实现称之为一个 "kernel",提供一个或多个 Op 的实现,即可以存在多个 kernel。 // 为这些 kernel 的每一个创建一个对应的类, 继承 AsyncOpKernel, 覆盖 ComputeAsync 方法。 // ComputeAsync 方法提供一个类型为 OpKernelContext* 的参数 context, 用于访问一些有用的信息, 例如输入和输出的 tensor。 class HorovodAllreduceOp : public AsyncOpKernel { public: // 防止类构造函数的隐式自动转换,只能显示调用该构造函数 explicit HorovodAllreduceOp(OpKernelConstruction* context) : AsyncOpKernel(context) {} // 重写ComputeAsync()方法 void ComputeAsync(OpKernelContext* context, DoneCallback done) override { OP_REQUIRES_OK_ASYNC(context, ConvertStatus(common::CheckInitialized()), done); auto node_name = name(); auto device = GetDeviceID(context); auto tensor = context->input(0); Tensor* output; OP_REQUIRES_OK_ASYNC( context, context->allocate_output(0, tensor.shape(), &output), done); // ReadyEvent makes sure input tensor is ready, and output is allocated. // shared_ptr 是一个标准的共享所有权的智能指针, 允许多个指针指向同一个对象 auto ready_event = std::shared_ptr<common::ReadyEvent>(RecordReadyEvent(context)); // 模板函数 std::make_shared 可以返回一个指定类型的 std::shared_ptr auto hvd_context = std::make_shared<TFOpContext>(context); auto hvd_tensor = std::make_shared<TFTensor>(tensor); auto hvd_output = std::make_shared<TFTensor>(*output); // 将张量的Allreduce操作OP加入队列,加入谁的队列?? auto enqueue_result = EnqueueTensorAllreduce( hvd_context, hvd_tensor, hvd_output, ready_event, node_name, device, [context, done](const common::Status& status) { context->SetStatus(ConvertStatus(status)); done(); }); OP_REQUIRES_OK_ASYNC(context, ConvertStatus(enqueue_result), done); } };
3. 注册OP到 TensorFlow 系统
// 3. 注册OP到 TensorFlow 系统 // 注册时可以指定该 kernel 运行时的多个约束条件. 例如可以指定一个 kernel 在 CPU 上运行, 另一个在 GPU 上运行 REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_CPU), HorovodAllreduceOp); // 如果执行了GPU #if HOROVOD_GPU_ALLREDUCE REGISTER_KERNEL_BUILDER(Name("HorovodAllreduce").Device(DEVICE_GPU), HorovodAllreduceOp); #endif
以horovd的HorovodAllreduceOp为例,学习如何在tensorflow上添加一个新的操作OP
标签:shared 有用 类构造 enqueue queue public mpi blank 访问
原文地址:https://www.cnblogs.com/lixiaolun/p/9163431.html