Error while converting complex models to ONNX (caused by view_as_complex)
🐛 Bug
When exporting some models that use complex operations to ONNX, the export fails because the ONNX operator set has no support for casting to complex types, which torch.view_as_complex(input) needs.
To Reproduce
torch.onnx.export(model_dccrn, input_random, 'model_dccrn.onnx', verbose=True, opset_version=11)
Steps to reproduce the behavior (code sample and stack trace):
~/lib/miniconda3/lib/python3.9/site-packages/torch/onnx/symbolic_registry.py in get_registered_op(opname, domain, version)
114 else:
115 msg += "Please feel free to request support or submit a pull request on PyTorch GitHub."
--> 116 raise RuntimeError(msg)
117 return _registry[(domain, version)][opname]
RuntimeError: Exporting the operator view_as_complex to ONNX opset version 12 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub.
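The report uses the full DCCRN model. A smaller, self-contained reproduction could look like the sketch below. `ToComplex` and `try_export` are hypothetical names, not part of Asteroid or PyTorch. With the versions listed in this report (PyTorch 1.9), the export is expected to fail with the error above; newer PyTorch releases may behave differently, so the outcome should be treated as version-dependent.

```python
import io

import torch
from torch import nn


class ToComplex(nn.Module):
    """Hypothetical toy module: stacks two real tensors and views them as complex."""

    def forward(self, re, im):
        # Same pattern as Asteroid's torch_complex_from_reim helper.
        return torch.view_as_complex(torch.stack([re, im], dim=-1))


def try_export(model, inputs, opset_version=11):
    """Attempt an ONNX export into memory; return None on success, else the error text."""
    buffer = io.BytesIO()
    try:
        torch.onnx.export(model, inputs, buffer, opset_version=opset_version)
    except Exception as err:  # the exact exception type differs between PyTorch versions
        return str(err)
    return None
```

Calling `try_export(ToComplex(), (torch.randn(2, 3), torch.randn(2, 3)))` returns the exporter's error message, or `None` if the installed PyTorch version can export the graph.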
Expected behavior
The conversion should complete without errors and produce a valid ONNX model.
Environment
Package versions
Run asteroid-versions and paste the output here:
Asteroid 0.5.1
PyTorch 1.9.0
PyTorch-Lightning 1.3.8
Additional info
I know this case is not covered by the PyTorch-to-ONNX operator set.
Current ONNX operators:
- https://github.com/onnx/onnx/blob/master/docs/Operators.md
- in the section on Cast: casting to complex is not supported.
Related issues:
- missing support for view_as_complex when converting to ONNX
- view_as_complex https://pytorch.org/docs/stable/generated/torch.view_as_complex.html
- https://github.com/onnx/onnx/issues/3173
- https://github.com/pytorch/pytorch/issues/49793
However, we could provide a wrapper that handles this conversion so that a valid ONNX model is produced.
view_as_complex is implemented in the ATen library:
- https://github.com/pytorch/pytorch/blob/30e48bbeae545c3292c2ab3fed0cb2dba4a92fed/aten/src/ATen/native/ComplexHelper.h#L70
const auto new_strides = computeStrideForViewAsComplex(self.strides());
const auto complex_type = c10::toComplexType(self.scalar_type());
view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
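The stride rule in that ATen snippet can be checked from Python: a real tensor whose last dimension has size 2 and stride 1 is reinterpreted with that dimension dropped, and every remaining stride is halved because one complex element spans two real ones. The helper `complex_view_strides` below is hypothetical; it restates the rule as we read it and compares the result with what `torch.view_as_complex` reports.

```python
import torch


def complex_view_strides(real_strides):
    """Hypothetical helper: strides view_as_complex assigns to a (..., 2) real tensor.

    The trailing size-2 axis must have stride 1; it is dropped, and the remaining
    strides (counted in real elements) are halved to count complex elements.
    """
    *outer, last = real_strides
    if last != 1:
        raise ValueError("last dimension must have stride 1")
    if any(s % 2 for s in outer):
        raise ValueError("all other strides must be divisible by 2")
    return tuple(s // 2 for s in outer)


x = torch.randn(3, 4, 2)          # contiguous real tensor
z = torch.view_as_complex(x)      # complex tensor of shape (3, 4)
print(x.stride(), z.stride(), complex_view_strides(x.stride()))
# (8, 2, 1) (4, 1) (4, 1)
```

No data is copied: the complex tensor shares storage with `x`, which is why the exporter has to represent a dtype reinterpretation rather than an ordinary arithmetic op.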
After investigating, I found that the conversion error originates in OnReIm:
def forward(self, x):
    return torch_complex_from_reim(self.re_module(x.real), self.im_module(x.imag))
which calls torch_complex_from_reim; it and its sibling torch_complex_from_magphase both rely on torch.view_as_complex:
def torch_complex_from_magphase(mag, phase):
    return torch.view_as_complex(
        torch.stack((mag * torch.cos(phase), mag * torch.sin(phase)), dim=-1)
    )

def torch_complex_from_reim(re, im):
    return torch.view_as_complex(torch.stack([re, im], dim=-1))
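One way to keep these helpers exportable, assuming downstream code can consume a real tensor with a trailing dimension of size 2 (the layout `torch.view_as_real` produces), is to drop the final `view_as_complex` call. `reim_from_reim`, `reim_from_magphase` and `OnReImReal` are hypothetical names, not Asteroid API; this is a sketch of the idea, not a tested replacement for the library code.

```python
import torch
from torch import nn


def reim_from_reim(re, im):
    # Same stacking as torch_complex_from_reim, minus view_as_complex:
    # result[..., 0] is the real part, result[..., 1] the imaginary part.
    return torch.stack([re, im], dim=-1)


def reim_from_magphase(mag, phase):
    # Same values as torch_complex_from_magphase, kept in the real (..., 2) layout.
    return torch.stack((mag * torch.cos(phase), mag * torch.sin(phase)), dim=-1)


class OnReImReal(nn.Module):
    """Hypothetical variant of OnReIm for inputs in the real (..., 2) layout."""

    def __init__(self, re_module, im_module):
        super().__init__()
        self.re_module = re_module
        self.im_module = im_module

    def forward(self, x):
        # x[..., 0] plays the role of x.real and x[..., 1] the role of x.imag.
        return reim_from_reim(self.re_module(x[..., 0]), self.im_module(x[..., 1]))
```

If complex values are needed after inference, `torch.view_as_complex` can still be applied to the output outside the exported graph.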
Also, asteroid_filterbanks contains this operation:
def to_torch_complex(tensor, dim: int = -2):
    return torch.view_as_complex(to_torchaudio(tensor, dim=dim))
https://github.com/asteroid-team/asteroid-filterbanks/blob/8a3d13fb0e495772bc9d1deac3327affe2833e10/asteroid_filterbanks/transforms.py#L327
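to_torch_complex first rearranges Asteroid's layout into a trailing dimension of size 2 and only then calls `view_as_complex`. As we understand the asteroid-filterbanks docs, that layout concatenates the real and imaginary halves along `dim`. An export-friendly path could stop after the first step. The sketch below restates that step under this assumption and is not the library code; `to_reim_stack` is a hypothetical name.

```python
import torch


def to_reim_stack(tensor, dim: int = -2):
    """Hypothetical helper: Asteroid-style [re; im] concatenation -> real (..., 2) layout.

    Splits `tensor` into two equal halves along `dim` (real part first) and stacks
    them on a new last axis, without ever creating a complex dtype.
    """
    if tensor.shape[dim] % 2:
        raise ValueError("dimension `dim` must have even size")
    re, im = torch.chunk(tensor, 2, dim=dim)
    return torch.stack([re, im], dim=-1)


re, im = torch.randn(1, 257, 10), torch.randn(1, 257, 10)
asteroid_style = torch.cat([re, im], dim=-2)   # shape (1, 514, 10)
stacked = to_reim_stack(asteroid_style)        # shape (1, 257, 10, 2), real dtype
```

`torch.view_as_complex(stacked)` then recovers a complex tensor equal to `torch.complex(re, im)`, so the complex view can be deferred to code that runs outside the exported graph.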
The problem is that ONNX doesn't support PyTorch's complex numbers. There isn't much we can do except create a facade for PyTorch's complex numbers that is based on real numbers. It's not a lot of work to implement (in fact, I've implemented it multiple times), but I'm not sure whether we should include that code in Asteroid.
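As a rough illustration of such a facade (not the implementation mentioned above, which isn't shown in this thread), the hypothetical `RealComplex` class below stores the real and imaginary parts as two ordinary float tensors and implements a few operations with real arithmetic only, so a traced graph never contains a complex dtype.

```python
import torch


class RealComplex:
    """Hypothetical complex-number facade built from two real tensors."""

    def __init__(self, re, im):
        self.re = re
        self.im = im

    def __add__(self, other):
        return RealComplex(self.re + other.re, self.im + other.im)

    def __mul__(self, other):
        # (a + ib)(c + id) = (ac - bd) + i(ad + bc)
        return RealComplex(
            self.re * other.re - self.im * other.im,
            self.re * other.im + self.im * other.re,
        )

    def conj(self):
        return RealComplex(self.re, -self.im)

    def abs(self):
        # Real code may need an epsilon here to avoid NaN gradients at 0 + 0i.
        return torch.sqrt(self.re ** 2 + self.im ** 2)

    def angle(self):
        return torch.atan2(self.im, self.re)

    def to_torch(self):
        # Only for checking results outside the exported graph.
        return torch.complex(self.re, self.im)
```

Every method maps to elementwise real ops (mul, add, sqrt, atan2), which is what makes this layout attractive for export.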
I guess it's possible for simple operations, but when linear solves and eigenvalue decompositions are computed, the facade becomes more complicated, right?
I don't really know what we should do about that.
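For linear solves, one standard workaround is to rewrite the complex system (A_re + i A_im)(x_re + i x_im) = b_re + i b_im as the real block system [[A_re, -A_im], [A_im, A_re]] [x_re; x_im] = [b_re; b_im], which is twice as large but has only real entries. The sketch below (hypothetical function `solve_complex_as_real`) shows this under that assumption. It only removes the complex dtype; whether the real solve itself exports to ONNX is a separate question, and eigendecompositions would need their own treatment.

```python
import torch


def solve_complex_as_real(a_re, a_im, b_re, b_im):
    """Hypothetical helper: solve (A_re + i A_im) x = b_re + i b_im with real tensors only.

    A_* have shape (..., n, n) and b_* have shape (..., n, k). Returns (x_re, x_im).
    """
    n = a_re.shape[-1]
    top = torch.cat([a_re, -a_im], dim=-1)      # real-part equations
    bottom = torch.cat([a_im, a_re], dim=-1)    # imaginary-part equations
    big_a = torch.cat([top, bottom], dim=-2)    # shape (..., 2n, 2n)
    big_b = torch.cat([b_re, b_im], dim=-2)     # shape (..., 2n, k)
    x = torch.linalg.solve(big_a, big_b)
    return x[..., :n, :], x[..., n:, :]
```

The block matrix is invertible exactly when the complex matrix is, so the real solve fails in the same cases as the complex one.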
@jonashaag
It's not a lot of work to implement. [...]
Can you explain your approach? Maybe show a code snippet?
Are you thinking of a dual-path design, or of doubling the input size of standard tensors (one part for magnitude, the other for phase) going into each module?
Unfortunately I don't have access to the code for a few days.
The approach I've taken is as follows:
- Move all t.abs(), t.angle(), t.view_as_complex() etc. calls on complex tensors to complex_nn.abs(t) etc. (and of course implement them there).
- Check that the code still works after those trivial code moves.
- Change the implementation of view_as_complex to return your preferred complex representation. See the module docstring in complex_nn.
- You'll also have to change some code that deals with shapes, like .permute(), .transpose(), and slices. Unfortunately, that code is scattered all over the models, so you'll have to make those changes step by step until it works.
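The steps above could be sketched as follows. Everything here is hypothetical and only mirrors the idea of routing calls through a complex_nn-style module: a `ComplexOps` object uses either native complex tensors or a real layout with a trailing dimension of size 2, and `transpose` shows the kind of dimension bookkeeping the last step refers to (negative dims shift by one in the real layout). The functions live on a class so that `abs` does not shadow Python's built-in.

```python
import torch


class ComplexOps:
    """Hypothetical complex_nn-style dispatcher.

    real_layout=False: complex values are native complex tensors.
    real_layout=True:  complex values are real tensors of shape (..., 2),
                       with [..., 0] the real and [..., 1] the imaginary part.
    """

    def __init__(self, real_layout: bool):
        self.real_layout = real_layout

    def view_as_complex(self, x):
        # Return the preferred representation of a real (..., 2) tensor.
        return x if self.real_layout else torch.view_as_complex(x)

    def abs(self, t):
        if self.real_layout:
            return torch.sqrt(t[..., 0] ** 2 + t[..., 1] ** 2)
        return t.abs()

    def angle(self, t):
        if self.real_layout:
            return torch.atan2(t[..., 1], t[..., 0])
        return t.angle()

    def _dim(self, d: int) -> int:
        # The extra trailing axis of the real layout shifts negative dims by one.
        return d - 1 if (self.real_layout and d < 0) else d

    def transpose(self, t, d0: int, d1: int):
        return t.transpose(self._dim(d0), self._dim(d1))
```

Model code would call `ops.abs(t)` instead of `t.abs()`, so switching `real_layout` to `True` before export changes the representation without touching each call site; `.permute()` and slicing would need the same `_dim` treatment.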
Thanks a lot! I am debugging some of the models with Asteroid's complex representation and there are still some errors. I think all of the failing spots are in complex_nn.