wonnx
Fuse mapping ops
Is your feature request related to a problem? Please describe.
When there are consecutive mapping operations (Neg, Relu, etc.), we should not execute each of them serially in its own shader. Instead, we should write a single shader that computes neg(relu(input)) in one go, provided that the intermediate result (here, the output of Relu) is not used elsewhere.
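As a CPU-side analogy (plain Rust, not wonnx or WGSL code), the difference between serial and fused execution can be sketched as follows. The unfused version materializes an intermediate buffer between the two ops, much like two separate shader dispatches would; the fused version makes a single pass. Both function names are hypothetical.

```rust
/// Unfused: one pass per op, like running one shader per op.
/// Relu writes an intermediate buffer that Neg then reads back.
fn relu_then_neg_unfused(input: &[f32]) -> Vec<f32> {
    let intermediate: Vec<f32> = input.iter().map(|&x| x.max(0.0)).collect();
    intermediate.iter().map(|&x| -x).collect()
}

/// Fused: a single pass computing neg(relu(x)) per element,
/// with no intermediate buffer.
fn relu_then_neg_fused(input: &[f32]) -> Vec<f32> {
    input.iter().map(|&x| -(x.max(0.0))).collect()
}
```

Both produce the same values; the fused form only saves the intermediate allocation and the extra pass, which on the GPU corresponds to an extra buffer and dispatch.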
Describe the solution you'd like
Fusing should happen in the optimizer. We can introduce a custom op type, wonnx.Map, that takes one input and an attribute listing the functions to apply consecutively (in the example above, it would contain Relu, Neg).
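A minimal sketch of how the compiler might turn the wonnx.Map op list into one WGSL expression, assuming scalar f32 elements and covering only a handful of ops. The function names and the op-to-WGSL mapping are assumptions for illustration, not existing wonnx code. Each op wraps the expression built so far, so the list is applied in order:

```rust
/// Wraps `inner` (a WGSL expression) in the WGSL for one unary op.
/// Assumes scalar f32 operands; only a few ops are covered here.
fn wrap_unary(op: &str, inner: &str) -> Option<String> {
    match op {
        "Relu" => Some(format!("max({inner}, 0.0)")),
        "Neg" => Some(format!("-({inner})")),
        "Abs" => Some(format!("abs({inner})")),
        "Exp" => Some(format!("exp({inner})")),
        _ => None, // unsupported op: caller should not fuse
    }
}

/// Builds one WGSL expression for a chain of unary ops applied in order,
/// e.g. ["Relu", "Neg"] applied to `x` yields neg(relu(x)).
fn fuse_unary_chain(ops: &[&str], input: &str) -> Option<String> {
    let mut expr = input.to_string();
    for op in ops {
        expr = wrap_unary(op, &expr)?;
    }
    Some(expr)
}
```

For vectorized (vec4<f32>) shaders, the Relu template would need a vector zero such as vec4<f32>(0.0) instead of 0.0.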
To also accommodate binary functions (Add, Sub, etc.), we might even allow an arbitrary number of inputs and have the attribute describe the desired operations in reverse Polish notation (RPN). For example, neg(relu(add(a, sub(b, c)))) would have three inputs (b, c, a, in that order), and the attribute could contain Push, Push, Sub, Push, Add, Relu, Neg. The compiler can then simply write out the corresponding WGSL.
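The RPN variant can be sketched with a stack of WGSL expression strings: each Push takes the next input name, binary ops pop two expressions, and unary ops pop one. This is a hypothetical helper (rpn_to_wgsl is not an existing wonnx function), assuming scalar f32 operands and treating the top of the stack as the right-hand operand of a binary op:

```rust
/// Translates an RPN op list into a single WGSL expression.
/// `inputs` are consumed in order, one per `Push`.
fn rpn_to_wgsl(ops: &[&str], inputs: &[&str]) -> Result<String, String> {
    let mut stack: Vec<String> = Vec::new();
    let mut next_input = inputs.iter();
    for &op in ops {
        match op {
            "Push" => {
                let name = next_input.next().ok_or("Push without a remaining input")?;
                stack.push(name.to_string());
            }
            "Add" | "Sub" | "Mul" | "Div" => {
                // Top of stack is the right-hand operand.
                let rhs = stack.pop().ok_or("stack underflow")?;
                let lhs = stack.pop().ok_or("stack underflow")?;
                let sym = match op { "Add" => "+", "Sub" => "-", "Mul" => "*", _ => "/" };
                stack.push(format!("({lhs} {sym} {rhs})"));
            }
            "Relu" => {
                let x = stack.pop().ok_or("stack underflow")?;
                stack.push(format!("max({x}, 0.0)"));
            }
            "Neg" => {
                let x = stack.pop().ok_or("stack underflow")?;
                stack.push(format!("-({x})"));
            }
            other => return Err(format!("unsupported op {other}")),
        }
    }
    if next_input.next().is_some() {
        return Err("unused inputs".to_string());
    }
    if stack.len() != 1 {
        return Err(format!("expected one result, found {}", stack.len()));
    }
    Ok(stack.pop().unwrap())
}
```

For the example above (ops Push, Push, Sub, Push, Add, Relu, Neg with inputs b, c, a), this yields -(max(((b - c) + a), 0.0)). The operands of Add come out as sub(b, c) + a rather than a + sub(b, c); that is harmless for Add, but a non-commutative op would need its inputs pushed in matching order.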
Describe alternatives you've considered
Fusing would also be possible at the shape inference stage.
We should also check whether the current ConvRelu optimization (which fuses Conv and Relu) works properly if the output from Conv is also used further on.
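The ConvRelu concern comes down to a guard the optimizer needs before any producer-plus-Relu fusion: the producer's output must have Relu as its only consumer and must not itself be a graph output. A sketch over a hypothetical, simplified Node type (wonnx's actual graph representation differs, and can_fuse_into_relu is not an existing function):

```rust
use std::collections::HashMap;

/// Hypothetical, simplified graph node: an op type plus tensor names.
struct Node {
    op_type: String,
    inputs: Vec<String>,
    outputs: Vec<String>,
}

/// Counts how many nodes consume each tensor.
fn consumer_counts(nodes: &[Node]) -> HashMap<&str, usize> {
    let mut counts = HashMap::new();
    for node in nodes {
        for input in &node.inputs {
            *counts.entry(input.as_str()).or_insert(0) += 1;
        }
    }
    counts
}

/// True only if `producer` has a single output that is consumed by exactly
/// one node, that node is a Relu, and the output is not a graph output.
fn can_fuse_into_relu(producer: &Node, nodes: &[Node], graph_outputs: &[&str]) -> bool {
    if producer.outputs.len() != 1 {
        return false;
    }
    let out = producer.outputs[0].as_str();
    let counts = consumer_counts(nodes);
    let only_consumer_is_relu = counts.get(out) == Some(&1)
        && nodes.iter().any(|n| n.op_type == "Relu" && n.inputs.iter().any(|i| i == out));
    let is_graph_output = graph_outputs.iter().any(|g| *g == out);
    only_consumer_is_relu && !is_graph_output
}
```

If the Conv output also feeds, say, an Add, the count is 2 and the guard refuses to fuse, which is the situation this issue asks to verify.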
Additional context