burn
Add a way to implement custom function for operators not supported in ONNX format
Feature description
Some operators are not supported by the ONNX format, so they can only be implemented as custom functions. These functions are then identified during parsing and glued into the model. For example, fft operators are implemented through custom functions. In PyTorch, the current process for exporting C++ operators is described at https://pytorch.org/docs/stable/onnx.html#c-operators and looks like this:
```python
import torch
from torch.onnx import symbolic_helper

# Define custom symbolic function: emits an ONNX node with
# domain "custom_domain" and op_type "Foo"
@symbolic_helper.parse_args("v", "v", "f", "i")
def symbolic_foo_forward(g, input1, input2, attr1, attr2):
    return g.op("custom_domain::Foo", input1, input2, attr1_f=attr1, attr2_i=attr2)

# Register custom symbolic function (opset 9)
torch.onnx.register_custom_op_symbolic("custom_ops::foo_forward", symbolic_foo_forward, 9)

class FooModel(torch.nn.Module):
    def __init__(self, attr1, attr2):
        super().__init__()
        self.attr1 = attr1
        self.attr2 = attr2

    def forward(self, input1, input2):
        # Calling custom op (the C++ op library must already be loaded,
        # e.g. with torch.ops.load_library)
        return torch.ops.custom_ops.foo_forward(input1, input2, self.attr1, self.attr2)

# attr1, attr2 and example_input1 are placeholders to fill in
model = FooModel(attr1, attr2)
torch.onnx.export(
    model,
    (example_input1, example_input1),
    "model.onnx",
    # only needed if you want to specify an opset version > 1.
    custom_opsets={"custom_domain": 2},
)
```
It would be helpful to implement the same approach in burn-import for Rust, and to test whether the behavior is correct using the model.onnx generated by the Python script above.
A list of some initial operators could be:
- [ ] rfft
- [ ] irfft
- [ ] rfftn
- [ ] irfftn
- [ ] fft
- [ ] ifft
- [ ] fftn
- [ ] ifftn
Feature motivation
Some custom operators are very pervasive, so having a way to retrieve and define them in burn could help convert many networks to different backends.
Also see: https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html
> For example, fft operators are implemented through custom functions

Isn't the FFT itself in the list of supported ONNX ops, as DFT? FFT is just the (family of) algorithm(s) for computing the DFT.
Also, would the goal be to support this directly in burn-import, or to add the option to support custom ops outside of those defined in burn-import?
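To make the DFT/FFT distinction concrete: the DFT is defined as X[k] = Σₙ x[n]·e^(−2πikn/N), and an FFT is any fast algorithm that produces the same values. Below is a minimal sketch that compares the direct O(N²) sum with a recursive radix-2 Cooley-Tukey FFT. It uses a hand-rolled `Complex` type so that no crates are needed, and it assumes the input length is a power of two.

```rust
use std::f64::consts::PI;
use std::ops::{Add, Mul, Sub};

/// Minimal complex number so the sketch needs no external crates.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Complex {
    re: f64,
    im: f64,
}

impl Complex {
    fn new(re: f64, im: f64) -> Self {
        Complex { re, im }
    }
    /// Unit phasor e^{i*theta}.
    fn expi(theta: f64) -> Self {
        Complex::new(theta.cos(), theta.sin())
    }
}

impl Add for Complex {
    type Output = Complex;
    fn add(self, o: Complex) -> Complex {
        Complex::new(self.re + o.re, self.im + o.im)
    }
}

impl Sub for Complex {
    type Output = Complex;
    fn sub(self, o: Complex) -> Complex {
        Complex::new(self.re - o.re, self.im - o.im)
    }
}

impl Mul for Complex {
    type Output = Complex;
    fn mul(self, o: Complex) -> Complex {
        Complex::new(self.re * o.re - self.im * o.im, self.re * o.im + self.im * o.re)
    }
}

/// Direct O(N^2) evaluation of X[k] = sum_n x[n] * e^{-2*pi*i*k*n/N}.
fn dft(x: &[Complex]) -> Vec<Complex> {
    let n = x.len();
    (0..n)
        .map(|k| {
            x.iter()
                .enumerate()
                .fold(Complex::new(0.0, 0.0), |acc, (j, &xj)| {
                    acc + xj * Complex::expi(-2.0 * PI * (k * j) as f64 / n as f64)
                })
        })
        .collect()
}

/// Recursive radix-2 Cooley-Tukey FFT: same values as `dft`, in O(N log N).
/// Assumes `x.len()` is a power of two.
fn fft(x: &[Complex]) -> Vec<Complex> {
    let n = x.len();
    if n <= 1 {
        return x.to_vec();
    }
    let even: Vec<Complex> = x.iter().step_by(2).copied().collect();
    let odd: Vec<Complex> = x.iter().skip(1).step_by(2).copied().collect();
    let (e, o) = (fft(&even), fft(&odd));
    let mut out = vec![Complex::new(0.0, 0.0); n];
    for k in 0..n / 2 {
        // Twiddle factor e^{-2*pi*i*k/N} applied to the odd-index half.
        let t = Complex::expi(-2.0 * PI * k as f64 / n as f64) * o[k];
        out[k] = e[k] + t;
        out[k + n / 2] = e[k] - t;
    }
    out
}

fn main() {
    let x: Vec<Complex> = [1.0, 2.0, 0.0, -1.0, 3.0, 0.5, -2.0, 4.0]
        .iter()
        .map(|&v| Complex::new(v, 0.0))
        .collect();
    let (a, b) = (dft(&x), fft(&x));
    for (p, q) in a.iter().zip(b.iter()) {
        assert!((p.re - q.re).abs() < 1e-9 && (p.im - q.im).abs() < 1e-9);
    }
    println!("DFT and FFT agree on {} bins", x.len());
}
```

Both functions compute the same transform. Only the cost differs, which is why a `DFT` op in ONNX can be backed by any FFT implementation.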
At the time, FFT wasn't supported. We still have a few unsupported complex functions that will take a while to support, so it is better if users can plug in their own functions.
I'm looking into how this would work with the `NodeType` used in burn-import. Looking into strum's `EnumString` a bit, it seems there is a `default` variant attribute that we might be able to use to capture whatever burn-import doesn't directly support yet.
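Something like:
```rust
use strum::EnumString;
#[derive(EnumString)] // alongside NodeType's existing derives
pub enum NodeType {
    // ...
    #[strum(default)]
    Custom(String),
}
```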
And then provide a way of adding custom functions if the op isn't directly implemented in burn-import. I'll add more info on how in a bit.
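Here is a minimal, dependency-free sketch of that idea. It makes several assumptions:

- `NodeType` is a trimmed-down, hypothetical stand-in with only `Add`, `Relu` and `Custom`.
- `FromStr` is written by hand to mimic the fallback that `#[strum(default)]` would generate.
- `CustomOpRegistry`, `CustomOpFn` and `lower` are hypothetical names that do not exist in burn-import.
- Handlers are keyed by ONNX (domain, op_type), e.g. `("custom_domain", "Foo")` from the export script above.

```rust
use std::collections::HashMap;
use std::str::FromStr;

/// Hypothetical, trimmed-down stand-in for burn-import's `NodeType`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum NodeType {
    Add,
    Relu,
    /// Fallback for any op type not known to the importer.
    Custom(String),
}

/// Hand-written equivalent of `EnumString` with `#[strum(default)]`:
/// unknown strings become `Custom(name)` instead of an error.
impl FromStr for NodeType {
    type Err = std::convert::Infallible;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(match s {
            "Add" => NodeType::Add,
            "Relu" => NodeType::Relu,
            other => NodeType::Custom(other.to_string()),
        })
    }
}

/// Hypothetical user-supplied handler: emits generated Rust code
/// for the node, given the names of its inputs.
type CustomOpFn = Box<dyn Fn(&[&str]) -> String>;

/// Hypothetical registry keyed by (domain, op_type).
#[derive(Default)]
struct CustomOpRegistry {
    ops: HashMap<(String, String), CustomOpFn>,
}

impl CustomOpRegistry {
    fn register(&mut self, domain: &str, op_type: &str, f: CustomOpFn) {
        self.ops.insert((domain.to_string(), op_type.to_string()), f);
    }

    /// Built-in ops get a placeholder; custom ops are looked up in the
    /// registry, and an unregistered custom op yields an error.
    fn lower(&self, domain: &str, node: &NodeType, inputs: &[&str]) -> Result<String, String> {
        match node {
            NodeType::Custom(op) => self
                .ops
                .get(&(domain.to_string(), op.clone()))
                .map(|f| f(inputs))
                .ok_or_else(|| format!("unsupported op {domain}::{op}")),
            builtin => Ok(format!("/* built-in {builtin:?} */")),
        }
    }
}

fn main() {
    let mut registry = CustomOpRegistry::default();
    // The user plugs in code generation for custom_domain::Foo
    // (`foo_forward` is a hypothetical user function).
    registry.register(
        "custom_domain",
        "Foo",
        Box::new(|ins: &[&str]| format!("let out = foo_forward({}, {});", ins[0], ins[1])),
    );

    let node: NodeType = "Foo".parse().unwrap();
    match registry.lower("custom_domain", &node, &["input1", "input2"]) {
        Ok(code) => println!("{code}"),
        Err(e) => eprintln!("{e}"),
    }
}
```

Keying on the domain as well as the op type means a user's `Foo` cannot collide with a standard op, or with another domain's op, of the same name.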