Add a way to implement custom functions for operators not supported in the ONNX format

Open Luni-4 opened this issue 2 years ago • 4 comments

Feature description

Some operators are not supported by the ONNX format and can only be implemented as custom functions. These functions are then identified during parsing and wired into the model. For example, fft operators are implemented through custom functions. In PyTorch, the current process for exporting C++ operators is described here: https://pytorch.org/docs/stable/onnx.html#c-operators

import torch
from torch.onnx import symbolic_helper

# Define a custom symbolic function that maps the op to an ONNX node in a
# custom domain ("v" = tensor value, "f" = float attribute, "i" = int attribute)
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)


# Register the symbolic function for the C++ op custom_ops::foo_forward
# (assumed to already be registered on the C++ side, e.g. via TORCH_LIBRARY)
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)


class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super().__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)


# Placeholder attributes and example inputs
attr1, attr2 = 1.0, 2
example_input1 = torch.randn(1, 3)
example_input2 = torch.randn(1, 3)

model = FooModel(attr1, attr2)
torch.onnx.export(
    model,
    (example_input1, example_input2),
    "model.onnx",
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"custom_domain": 2},
)

It would be helpful to implement the same approach in burn-import for Rust, and to verify that the behavior is correct using the model.onnx generated by the Python script above.
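As a rough illustration, here is a minimal sketch of what such a plug-in mechanism could look like on the burn-import side. Everything here is hypothetical: OnnxNode, CustomOpHandler, and CustomOpRegistry do not exist in burn-import today, and handlers simply emit Rust source as a String for brevity. The idea is just a registry keyed by (ONNX domain, op type) that is consulted when the parser hits an unknown node:

use std::collections::HashMap;

/// Hypothetical, simplified view of a parsed ONNX node.
struct OnnxNode {
    domain: String,
    op_type: String,
    inputs: Vec<String>,
    outputs: Vec<String>,
}

/// A user-supplied handler that turns a custom node into generated Rust code.
type CustomOpHandler = fn(&OnnxNode) -> String;

/// Hypothetical registry mapping (domain, op_type) to a handler.
struct CustomOpRegistry {
    handlers: HashMap<(String, String), CustomOpHandler>,
}

impl CustomOpRegistry {
    fn new() -> Self {
        Self { handlers: HashMap::new() }
    }

    fn register(&mut self, domain: &str, op_type: &str, handler: CustomOpHandler) {
        self.handlers.insert((domain.into(), op_type.into()), handler);
    }

    /// Returns generated code for the node, or None if no handler is registered.
    fn codegen(&self, node: &OnnxNode) -> Option<String> {
        self.handlers
            .get(&(node.domain.clone(), node.op_type.clone()))
            .map(|handler| handler(node))
    }
}

fn main() {
    let mut registry = CustomOpRegistry::new();

    // Handle the custom_domain::Foo node produced by the Python script above.
    registry.register("custom_domain", "Foo", |node| {
        format!(
            "let {} = foo_forward({}, {});",
            node.outputs[0], node.inputs[0], node.inputs[1]
        )
    });

    let node = OnnxNode {
        domain: "custom_domain".into(),
        op_type: "Foo".into(),
        inputs: vec!["input1".into(), "input2".into()],
        outputs: vec!["output".into()],
    };

    println!("{}", registry.codegen(&node).unwrap());
}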

A list of some initial operators could be:

  • [ ] rfft
  • [ ] irfft
  • [ ] rfftn
  • [ ] irfftn
  • [ ] fft
  • [ ] ifft
  • [ ] fftn
  • [ ] ifftn

Feature motivation

Some custom operators are very pervasive, so having a way to retrieve and define them in Burn could help convert many networks across different backends.

Luni-4 avatar Sep 06 '23 17:09 Luni-4

Also see: https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html

antimora avatar Nov 28 '23 04:11 antimora

For example, fft operators are implemented through custom functions

Isn't the fft itself in the list of supported ONNX ops as DFT? FFT is just the (family of) algorithm(s) for computing the DFT.
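For reference, the DFT that the ONNX DFT op computes for a length-$N$ signal $x$ is $X_k = \sum_{n=0}^{N-1} x_n \, e^{-2\pi i k n / N}$ for $k = 0, \dots, N-1$; the FFT is just an $O(N \log N)$ algorithm for evaluating it.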

Also, would the goal be to support this directly in burn-import, or to add the option to support custom ops outside of those defined in burn-import?

skewballfox avatar Jun 06 '24 18:06 skewballfox

For example, fft operators are implemented through custom functions

Isn't the fft itself in the list of supported ONNX ops as DFT? FFT is just the (family of) algorithm(s) for computing the DFT.

Also, would the goal be to support this directly in burn-import, or to add the option to support custom ops outside of those defined in burn-import?

At the time, fft wasn't supported. We still have a few unsupported complex functions that will take a while to support. It is better if users can plug in their own functions.
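To illustrate the kind of function a user might plug in for a still-unsupported op, here is a naive O(N²) DFT, written over plain slices for brevity rather than Burn tensors (a real plug-in would operate on the backend's tensor type):

use std::f32::consts::PI;

/// Naive O(N^2) DFT: returns a (real, imaginary) pair per frequency bin.
/// X_k = sum_n x_n * exp(-2*pi*i*k*n/N)
fn dft(signal: &[f32]) -> Vec<(f32, f32)> {
    let n = signal.len();
    (0..n)
        .map(|k| {
            let mut re = 0.0;
            let mut im = 0.0;
            for (i, &x) in signal.iter().enumerate() {
                let angle = -2.0 * PI * (k as f32) * (i as f32) / (n as f32);
                re += x * angle.cos();
                im += x * angle.sin();
            }
            (re, im)
        })
        .collect()
}

fn main() {
    // The DFT of an impulse is flat: every bin should be (1.0, 0.0).
    let spectrum = dft(&[1.0, 0.0, 0.0, 0.0]);
    println!("{:?}", spectrum);
}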

antimora avatar Jun 06 '24 19:06 antimora

I'm looking into how this would work with the NodeType enum used in burn-import. Looking into strum's EnumString a bit, it seems there is a default variant attribute we might be able to use to capture whatever burn-import doesn't directly support yet.

something like

pub enum NodeType {
    // ...existing ONNX op variants...
    #[strum(default)]
    Custom(String),
}
And then provide a way of adding custom functions if the op isn't directly implemented in burn-import. I'll add more info on how in a bit.
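For reference, a minimal sketch of how strum's #[strum(default)] behaves (assuming the strum and strum_macros crates as dependencies): any string that doesn't match a known variant falls through to the default variant instead of failing to parse.

use std::str::FromStr;
use strum_macros::EnumString;

#[derive(Debug, PartialEq, EnumString)]
enum NodeType {
    Add,
    MatMul,
    // Any op name we don't recognize is captured here instead of erroring.
    #[strum(default)]
    Custom(String),
}

fn main() {
    assert_eq!(NodeType::from_str("Add").unwrap(), NodeType::Add);
    assert_eq!(
        NodeType::from_str("Foo").unwrap(),
        NodeType::Custom("Foo".to_string())
    );
}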

skewballfox avatar Jun 08 '24 16:06 skewballfox