码迷,mamicode.com
首页 > 其他好文 > 详细

PyTorch自定义OP导出ONNX

时间:2020-06-04 14:12:08      阅读:193      评论:0      收藏:0      [点我收藏+]

标签:第一个   for   method   int   port   eth   input   官方文档   lan   

根据PyTorch的官方文档,需要用Function封装一下,为了能够导出ONNX需要加一个symbolic静态方法:

class relu5_func(Function):
    @staticmethod
    def forward(ctx, input):
        return relu5_cuda.relu5(input)
    @staticmethod
    def symbolic(g, *inputs):
        return g.op("Relu5", inputs[0], myattr_f=1.0) 
        # 这里第一个参数"Relu5"表示ONNX输出命名
        # myattr可以随便取,表示一个属性名,_f表示是一个float类型
relu5 = relu5_func.apply

定义好后,用以下代码测试

import torch
import torch.nn as nn
import relu5_cuda
import onnx
from torch.autograd import Function
from torch.autograd.function import once_differentiable
import netron

class TinyNet(nn.Module):
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        x = x.view(-1)
        x = relu5(x)
        return x

net = TinyNet().cuda()
ipt = torch.ones(2,3,12,12).cuda()
torch.onnx.export(net, (ipt,), ‘tinynet.onnx‘)
print(onnx.load(‘tinynet.onnx‘))
netron.start(‘tinynet.onnx‘)

PyTorch自定义OP导出ONNX

标签:第一个   for   method   int   port   eth   input   官方文档   lan   

原文地址:https://www.cnblogs.com/xytpai/p/13042667.html

(0)
(0)
   
举报
评论 一句话评论(0
登录后才能评论!
© 2014 mamicode.com 版权所有  联系我们:gaon5@hotmail.com
迷上了代码!