
How to make `irfft` in ONNX

Open · grazder opened this issue 1 year ago · 12 comments

Ask a Question

Question

Hello! Can you point out how irfft can be used in ONNX? I found issues and documentation on using rfft, but nothing about irfft.

I found that https://github.com/onnx/onnx/issues/1646 and https://github.com/onnx/onnx/issues/3573 were closed with the comment "All the other ops from the original list were added at some point." However, I can't find any information about irfft.

I would be glad to help!

grazder, Feb 07 '24 12:02

Related: https://github.com/pytorch/pytorch/issues/119360

grazder, Feb 07 '24 12:02

@justinchuby

grazder, Feb 07 '24 12:02

Is this the best operator to use?

https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Irfft

grazder, Feb 07 '24 13:02

Thanks for reporting this issue! The pytorch issue you pointed to is the best for tracking. We will make sure it is fixed.

justinchuby, Feb 07 '24 17:02

Actually, I still have a question about irfft. In my case I can't use dynamo_export because LSTMs are not currently supported there, so I have to implement aten::fft_irfft manually to export my model with torch.onnx.export and torch.onnx.register_custom_op_symbolic.

I've tried something like this, but the implementation is wrong:

from onnxscript import FLOAT, script
from onnxscript import opset17 as op

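@script()
def Irfft(X: FLOAT[481, 2]):
    x = op.Unsqueeze(X, [0])  # add a batch axis: [1, 481, 2]
    # Two-sided inverse DFT of length 960: the 481 bins are zero-padded,
    # so the conjugate (negative-frequency) half is never reconstructed.
    x = op.DFT(x, 960, axis=1, inverse=1, onesided=0)
    x = op.Squeeze(x, [0])  # complex result as [960, 2] (real, imag)
    return x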
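For reference, a worked comparison of what the graph above computes versus irfft. This assumes that DFT zero-pads the input up to dft_length and that the inverse transform is scaled by 1/n, as numpy.fft.ifft is. With n = 960 and input bins X_0 through X_480, the graph returns

$$
y_t = \frac{1}{960}\sum_{k=0}^{480} X_k \, e^{2\pi i k t/960},
$$

whereas irfft fills bins 481..959 with the conjugates of bins 479..1 and ignores the imaginary parts of the DC and Nyquist bins:

$$
x_t = \frac{1}{960}\Big(\operatorname{Re}X_0 + (-1)^t \operatorname{Re}X_{480} + 2\sum_{k=1}^{479}\operatorname{Re}\big(X_k e^{2\pi i k t/960}\big)\Big)
= 2\operatorname{Re}(y_t) - \frac{\operatorname{Re}X_0 + (-1)^t \operatorname{Re}X_{480}}{960}.
$$

So the output is off by the missing conjugate half (the factor of 2 on the real part) and by a DC/Nyquist correction term.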

I've also tried implementing irfft manually on top of ifft:

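import torch
from torch import nn

class iRFFTModel(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # x: one-sided spectrum as [481, 2] (real, imag)
        x_conj = x * torch.tensor([[1, -1]])  # complex conjugate
        # Mirror bins 479..1 (drop the Nyquist and DC bins): 479 bins
        x_conj = torch.flip(x_conj, dims=(0,))[1:-1]
        x = torch.cat((x, x_conj))  # full 960-bin Hermitian spectrum
        x = torch.fft.ifft(torch.view_as_complex(x)).real
        return x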

but it runs much slower. Do you have any recommendations on how to implement it using onnxscript or torch's jit_utils.GraphContext ops?
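The manual ifft model above hardcodes the even case (481 bins to 960 samples): after flipping it drops both the first and last bin. Below is a NumPy sketch of the same Hermitian-extension idea for a general output length n, where only the mirrored slice depends on whether n is even or odd. The function name `hermitian_extend_irfft` is hypothetical, and this only illustrates the index bookkeeping; it says nothing about speed.

```python
import numpy as np


def hermitian_extend_irfft(spec: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical helper: irfft by rebuilding the full n-bin spectrum, then ifft."""
    m = spec.shape[0]
    if m != n // 2 + 1:
        raise ValueError("expected n // 2 + 1 one-sided bins")
    if n % 2 == 0:
        # Even n: the last bin is Nyquist and has no mirror; mirror bins m-2 .. 1.
        mirrored = np.conj(spec[-2:0:-1])
    else:
        # Odd n: there is no Nyquist bin; mirror bins m-1 .. 1.
        mirrored = np.conj(spec[:0:-1])
    full = np.concatenate([spec, mirrored])  # length n
    # .real also drops any imaginary part on the DC/Nyquist bins,
    # which irfft ignores as well.
    return np.fft.ifft(full).real


rng = np.random.default_rng(0)
for n in (960, 961):
    x = rng.standard_normal(n)
    spec = np.fft.rfft(x)  # n // 2 + 1 = 481 bins in both cases
    print(n, np.allclose(hermitian_extend_irfft(spec, n), x))
```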

grazder, Feb 08 '24 06:02

Is it possible to express it as a function of DFT?
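One way to do it with a single two-sided inverse DFT: halve the DC bin (and, for even n, the Nyquist bin), run the inverse DFT zero-padded to n, and double the real part. Every interior bin then also accounts for its missing conjugate partner. The NumPy sketch below checks this on the [bins, 2] (real, imag) layout that ONNX DFT uses. `np.fft.ifft(..., n=n)` stands in for `DFT(inverse=1, onesided=0)` with dft_length n, assuming the ONNX op likewise zero-pads and scales by 1/n. The helper name is hypothetical.

```python
import numpy as np


def irfft_via_two_sided_idft(x_ri: np.ndarray, n: int) -> np.ndarray:
    """Hypothetical helper: irfft of a one-sided spectrum stored as [bins, 2].

    Each step has an ONNX counterpart: Mul (weights), a two-sided inverse
    DFT zero-padded to n, taking the real part, and Mul by 2.
    """
    bins = x_ri.shape[0]
    if bins != n // 2 + 1:
        raise ValueError("expected n // 2 + 1 one-sided bins")
    # DC (and Nyquist for even n) have no conjugate partner, so halve them;
    # every other bin stands in for itself and its mirror image.
    w = np.ones(bins)
    w[0] = 0.5
    if n % 2 == 0:
        w[-1] = 0.5
    spec = (x_ri[:, 0] + 1j * x_ri[:, 1]) * w
    # Zero-pads to n and divides by n, like an unnormalized-forward inverse DFT.
    y = np.fft.ifft(spec, n=n)
    return 2.0 * y.real


rng = np.random.default_rng(0)
signal = rng.standard_normal(960)
spec = np.fft.rfft(signal)                        # 481 bins
x_ri = np.stack([spec.real, spec.imag], axis=-1)  # [481, 2], ONNX DFT layout
out = irfft_via_two_sided_idft(x_ri, 960)
print(out.shape, np.allclose(out, signal))
```

Compared with the explicit Hermitian extension, this avoids the flip and concat, at the cost of relying on DFT's zero-padding behaviour.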

xadupre, Feb 08 '24 13:02

Does setting onesided to 1 work?

justinchuby, Feb 08 '24 15:02

import torch
from onnxscript import FLOAT, script
from onnxscript import opset17 as op

@script()
def Irfft(X: FLOAT[481, 2]):
    x = op.Unsqueeze(X, [0])
    x = op.DFT(x, 960, axis=1, inverse=1, onesided=1)
    x = op.Squeeze(x, [0])
    return x

x = torch.randn(481, 2).detach().cpu().numpy()
Irfft(x).shape
[ONNXRuntimeError] : 1 : FAIL : Node () Op (DFT) [ShapeInferenceError] is_onesided and inverse attributes cannot be enabled at the same time

The onesided option is only available for rfft; with inverse=1 it doesn't work.

grazder, Feb 08 '24 17:02

Thanks, I created https://github.com/onnx/onnx/issues/5920. For now, we will update the ONNX Script implementation with help from @titaiwangms.

justinchuby, Feb 08 '24 17:02

So, do I understand correctly that there is currently no way to create an irfft node using DFT, and we should wait for a fix?

grazder, Feb 08 '24 19:02

There is. I think you (or we) will just have to figure out how to restore the n. You may also consider https://github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.Irfft, as you mentioned, and register a symbolic function in PyTorch (taking the g.op GraphContext path).
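One reading of "restore the n": a one-sided spectrum does not determine the output length by itself, because signals of length 2m - 2 and 2m - 1 both give m bins. So an irfft lowering needs n passed in explicitly. The NumPy sketch below shows this for the 481-bin case from this thread. It illustrates the ambiguity only, not any exporter behaviour.

```python
import numpy as np

rng = np.random.default_rng(0)
even = rng.standard_normal(960)
odd = rng.standard_normal(961)

# Both signals produce a one-sided spectrum with n // 2 + 1 = 481 bins.
spec_even = np.fft.rfft(even)
spec_odd = np.fft.rfft(odd)
print(spec_even.shape, spec_odd.shape)  # (481,) (481,)

# Without an explicit n, irfft assumes the even length 2 * (481 - 1) = 960 ...
print(np.fft.irfft(spec_odd).shape)  # (960,)
# ... so the odd-length signal is only recovered when n is passed explicitly.
print(np.allclose(np.fft.irfft(spec_odd, n=961), odd))  # True
```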

justinchuby, Feb 08 '24 19:02

Yeah, I've already tried it, but it is currently only available for CUDA, and I need a CPU implementation :c
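Since the contrib op is CUDA-only, another option is to precompute the real inverse DFT for a fixed n as two constant matrices. The graph then only needs MatMul and Sub, which the default CPU execution provider supports. This technique is not proposed anywhere in this thread. It uses x_t = (1/n) Σ_k w_k (Re X_k cos θ_kt - Im X_k sin θ_kt), with w_k = 1 for DC (and Nyquist when n is even) and 2 otherwise. The NumPy sketch below assumes n is known at export time, and its helper names are hypothetical.

```python
import numpy as np


def irfft_basis(n: int):
    """Hypothetical helper: constant [bins, n] matrices for a real inverse DFT."""
    bins = n // 2 + 1
    k = np.arange(bins)[:, None]  # frequency index, shape [bins, 1]
    t = np.arange(n)[None, :]     # time index, shape [1, n]
    # Reduce k*t modulo n before scaling to keep the angles small.
    angle = 2.0 * np.pi * ((k * t) % n) / n
    # Interior bins count twice (they stand in for their conjugate partner);
    # DC and, for even n, Nyquist count once.
    w = np.full((bins, 1), 2.0)
    w[0] = 1.0
    if n % 2 == 0:
        w[-1] = 1.0
    return w * np.cos(angle) / n, w * np.sin(angle) / n


def irfft_matmul(x_ri: np.ndarray, cos_basis: np.ndarray, sin_basis: np.ndarray) -> np.ndarray:
    """x_ri has shape [..., bins, 2] (real, imag), matching ONNX DFT's layout."""
    # Re(X e^{i theta}) = Re(X) cos(theta) - Im(X) sin(theta), summed over bins.
    return x_ri[..., 0] @ cos_basis - x_ri[..., 1] @ sin_basis


n = 960
cos_basis, sin_basis = irfft_basis(n)
signal = np.random.default_rng(0).standard_normal(n)
spec = np.fft.rfft(signal)
x_ri = np.stack([spec.real, spec.imag], axis=-1)  # [481, 2]
print(np.allclose(irfft_matmul(x_ri, cos_basis, sin_basis), signal))
```

The trade-off is O(n²) work per frame plus two 481 × 960 constant initializers in the model. Whether that beats a DFT-based graph on CPU would need to be measured.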

grazder, Feb 08 '24 19:02