How to transform custom ops to onnx

Open: zgplvyou opened this issue 2 years ago • 5 comments

I'm trying to export custom ops to ONNX the same way the NMS op is set up. I defined a static symbolic method. When the op has a single output, I get the custom ONNX node. However, when it has more than one output, the custom node is optimized away. Here is my simple code:

import torch
import torch.nn as nn


def func1(a, b):
    c = torch.add(a, b)
    d = torch.sub(c, a)
    return [c, d]  # two outputs, returned as a list

class NMSop(torch.autograd.Function):

    @staticmethod
    def forward(ctx, a, thr):
        return func1(a, thr)

    @staticmethod
    def symbolic(g, a, thr):
        # map the whole op to one custom ONNX node in the "Aten" domain
        x = g.op('Aten::NMS', a, thr)
        return x

class CustomNet(nn.Module):
    def __init__(self):
        super(CustomNet, self).__init__()

    def forward(self, a, b):
        x = NMSop.apply(a, b)
        return x

net = CustomNet()
t = torch.randn(3, 4)
s = torch.randn(3, 4)
torch.onnx.export(net, (t, s), 'test.onnx', verbose=True)

If func1 returns c, I get the NMS node; if it returns [c, d], I don't. Is there a way to handle multiple outputs?

zgplvyou, Sep 07 '22 12:09

  1. Return a tuple instead of a list.
  2. Add outputs=<number of outputs> to g.op(...).
  3. Make sure you are using torch>=1.7.

grimoire, Sep 08 '22 01:09

Awesome, thank you for your solution. It works.

zgplvyou, Sep 08 '22 05:09

@grimoire By the way, I use the symbolic method to register custom ONNX ops, but the exported ONNX nodes are missing shape information. Is there any way to retain it?

zgplvyou, Sep 08 '22 12:09

Sorry, I don't know how to add shape inference to a symbolic function in PyTorch. Maybe you can add a custom schema with this.

grimoire, Sep 09 '22 02:09

OK. I traced the graph, and the shape information is already missing in the JIT graph. I will keep looking for a solution next week. Anyway, happy Mid-Autumn Festival; today is winding down anyway.

zgplvyou, Sep 09 '22 02:09