tvm
tvm copied to clipboard
[feature][relax.frontend.torch] Missing coverage for STFT+RNN pipeline: 'rnn_tanh.input', 'real.default', 'imag.default', 'unfold.default', 'fft_fft.default' in from_exported_program
Summary
Importing a toy STFT + RNN model exported with torch.export fails in the TVM Relax Torch frontend with:
AssertionError: Unsupported function types ['rnn_tanh.input', 'real.default', 'unfold.default', 'imag.default', 'fft_fft.default']
This highlights several gaps in op coverage that are common in audio pipelines: 1D framing (unfold), FFT (fft_fft), complex tensor accessors (real/imag), and fused RNN (rnn_tanh.input).
Environment
- OS: (Ubuntu 22.04.4 LTS (x86_64))
- TVM version: (release v0.21.0)
- Python: (3.10.16)
- LLVM: (17.0.6)
- PyTorch: (2.7.1)
Steps to reproduce
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_input(batch=1, length=4096, device="cpu", dtype=torch.float32):
    """Return a random ``(batch, length)`` tensor to use as model input.

    Values are drawn with ``torch.randn`` on the requested device and dtype.
    """
    return torch.randn(batch, length, device=device, dtype=dtype)
class MiniSTFTRNN(nn.Module):
    """Toy STFT + RNN model used to reproduce the unsupported-op assertion.

    The forward pass frames the signal with ``unfold``, applies a Hann
    window, takes the per-frame FFT magnitude, runs it through a tanh RNN
    and projects the last time step with a linear layer.
    """

    def __init__(self, win_len=320, n_fft=512, hop=160):
        super().__init__()
        self.register_buffer("window", torch.hann_window(win_len))
        self.win_len = win_len
        self.n_fft = n_fft
        self.hop = hop
        # use per-frame FFT spectrum as features
        self.rnn = nn.RNN(input_size=n_fft, hidden_size=8, num_layers=1,
                          batch_first=True, nonlinearity="tanh")
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        # x: (B, L)
        # NOTE(review): the tail pad is n_fft - (L % hop), not hop - (L % hop),
        # so the padded length is not necessarily a multiple of hop. Kept
        # exactly as reported so the reproduction is unchanged.
        pad_tail = (0, self.n_fft - (x.shape[-1] % self.hop)) if (x.shape[-1] % self.hop) != 0 else (0, 0)
        x = F.pad(x, pad_tail, mode="constant")  # align for framing
        frames = x.unfold(-1, self.win_len, self.hop)  # (B, T, win_len) -> aten::unfold
        frames = frames * self.window  # windowing
        spec = torch.fft.fft(frames, n=self.n_fft)  # aten::fft_fft
        real = spec.real  # aten::real.default
        imag = spec.imag  # aten::imag.default
        mag = torch.sqrt(real * real + imag * imag)  # (B, T, n_fft)
        out, _ = self.rnn(mag)  # aten::rnn_tanh.input
        return self.fc(out[:, -1, :])  # (B, 4)
def main():
    """Export the toy model with torch.export and import it into TVM Relax.

    The final ``from_exported_program`` call is the step that raises the
    "Unsupported function types" assertion described in this report.
    """
    import numpy as np
    from torch.export import export as torch_export
    from tvm.relax.frontend.torch import from_exported_program

    torch.manual_seed(0)
    np.random.seed(0)
    model = MiniSTFTRNN().eval()
    inp = get_input(batch=1, length=4096)
    # NOTE(review): the original snippet lost its indentation; only the
    # sanity forward pass is assumed to sit inside inference_mode -- confirm.
    with torch.inference_mode():
        _ = model(inp)  # sanity
    ep = torch_export(model, (inp,))
    mod = from_exported_program(ep)  # <- raises assertion
# Run the reproduction only when executed as a script.
if __name__ == "__main__":
    main()
Output
Traceback (most recent call last):
...
File ".../base_fx_graph_translator.py", line 116, in _check_unsupported_func_type
assert not missing_func_types, f"Unsupported function types {missing_func_types}"
AssertionError: Unsupported function types ['rnn_tanh.input', 'real.default', 'unfold.default', 'imag.default', 'fft_fft.default']
Triage
- needs-triage
- bug