
[feature][relax.frontend.torch] Missing coverage for STFT+RNN pipeline: 'rnn_tanh.input', 'real.default', 'imag.default', 'unfold.default', 'fft_fft.default' in from_exported_program

Open tinywisdom opened this issue 2 months ago • 0 comments

Summary

Importing a toy STFT + RNN model exported with torch.export fails in the TVM Relax Torch frontend (from_exported_program) with:

AssertionError: Unsupported function types ['rnn_tanh.input', 'real.default', 'unfold.default', 'imag.default', 'fft_fft.default']

This exposes gaps in operator coverage for ops that are common in audio pipelines: 1D framing (unfold), FFT (fft_fft), complex tensor accessors (real/imag), and the fused RNN op (rnn_tanh.input).
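For reference, here is a minimal pure-Python sketch of what each of these op families computes, written with real-valued arithmetic only. It is meant to illustrate the semantics a converter would need to reproduce, not how TVM or PyTorch implement them. The helper names (frame_1d, dft_real_imag, rnn_tanh_step) are made up for this illustration.

```python
import math

def frame_1d(signal, size, step):
    """Sliding windows over a 1D list, mirroring Tensor.unfold(-1, size, step)."""
    count = (len(signal) - size) // step + 1
    return [signal[i * step : i * step + size] for i in range(count)]

def dft_real_imag(frame, n_fft):
    """Real and imaginary parts of an n_fft-point DFT, using real arithmetic only.

    The frame is zero-padded or trimmed to n_fft samples first, which is what
    torch.fft.fft(x, n=n_fft) does before transforming.
    """
    padded = (list(frame) + [0.0] * n_fft)[:n_fft]
    real, imag = [], []
    for k in range(n_fft):
        angles = [2.0 * math.pi * k * n / n_fft for n in range(n_fft)]
        # X[k] = sum_n x[n] * (cos(angle) - i * sin(angle))
        real.append(sum(x * math.cos(a) for x, a in zip(padded, angles)))
        imag.append(-sum(x * math.sin(a) for x, a in zip(padded, angles)))
    return real, imag

def rnn_tanh_step(x, h, w_ih, w_hh, b_ih, b_hh):
    """One Elman RNN step: h' = tanh(W_ih x + b_ih + W_hh h + b_hh).

    w_ih has shape (hidden, input) and w_hh has shape (hidden, hidden),
    matching the layout of nn.RNN's weight_ih_l0 and weight_hh_l0.
    """
    return [
        math.tanh(
            sum(w * xi for w, xi in zip(w_ih[j], x)) + b_ih[j]
            + sum(w * hi for w, hi in zip(w_hh[j], h)) + b_hh[j]
        )
        for j in range(len(h))
    ]
```

Since the real and imaginary parts reduce to two real matrix products (against cosine and negated sine tables), one possible workaround is to rewrite the model so that no complex tensors appear in the exported graph. Whether that sidesteps every unsupported op in this frontend is untested here.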

Environment

  • OS: Ubuntu 22.04.4 LTS (x86_64)
  • TVM version: release v0.21.0
  • Python: 3.10.16
  • LLVM: 17.0.6
  • PyTorch: 2.7.1

Steps to reproduce

import torch
import torch.nn as nn
import torch.nn.functional as F

def get_input(batch=1, length=4096, device="cpu", dtype=torch.float32):
    return torch.randn(batch, length, device=device, dtype=dtype)

class MiniSTFTRNN(nn.Module):
    def __init__(self, win_len=320, n_fft=512, hop=160):
        super().__init__()
        self.register_buffer("window", torch.hann_window(win_len))
        self.win_len = win_len
        self.n_fft = n_fft
        self.hop = hop
        # use per-frame FFT spectrum as features
        self.rnn = nn.RNN(input_size=n_fft, hidden_size=8, num_layers=1,
                          batch_first=True, nonlinearity="tanh")
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        # x: (B, L)
        pad_tail = (0, self.n_fft - (x.shape[-1] % self.hop)) if (x.shape[-1] % self.hop) != 0 else (0, 0)
        x = F.pad(x, pad_tail, mode="constant")                # align for framing
        frames = x.unfold(-1, self.win_len, self.hop)          # (B, T, win_len) -> aten::unfold
        frames = frames * self.window                           # windowing
        spec = torch.fft.fft(frames, n=self.n_fft)              # aten::fft_fft
        real = spec.real                                        # aten::real.default
        imag = spec.imag                                        # aten::imag.default
        mag = torch.sqrt(real * real + imag * imag)             # (B, T, n_fft)
        out, _ = self.rnn(mag)                                  # aten::rnn_tanh.input
        return self.fc(out[:, -1, :])                           # (B, 4)

def main():
    import numpy as np
    from torch.export import export as torch_export
    from tvm.relax.frontend.torch import from_exported_program

    torch.manual_seed(0); np.random.seed(0)
    model = MiniSTFTRNN().eval()
    inp = get_input(batch=1, length=4096)
    with torch.inference_mode():
        _ = model(inp)  # sanity

    ep = torch_export(model, (inp,))
    mod = from_exported_program(ep)  # <- raises assertion

if __name__ == "__main__":
    main()

Output

Traceback (most recent call last):
  ...
  File ".../base_fx_graph_translator.py", line 116, in _check_unsupported_func_type
    assert not missing_func_types, f"Unsupported function types {missing_func_types}"
AssertionError: Unsupported function types ['rnn_tanh.input', 'real.default', 'unfold.default', 'imag.default', 'fft_fft.default']
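The traceback suggests the frontend compares the name of each call_function node's target against its table of converters and asserts on anything left over. A rough sketch of that kind of check follows. It assumes target names are exposed via `__name__`, as the names in the error message imply. The function name, the `supported` set, and the FakeTarget stand-in are hypothetical and exist only so the sketch runs without PyTorch or TVM installed.

```python
from types import SimpleNamespace

def missing_call_function_names(nodes, supported):
    """Return the sorted names of call_function targets not present in `supported`."""
    names = {
        getattr(node.target, "__name__", str(node.target))
        for node in nodes
        if node.op == "call_function"
    }
    return sorted(names - set(supported))

class FakeTarget:
    """Hypothetical stand-in for an operator overload object with a __name__."""
    def __init__(self, name):
        self.__name__ = name

# Stand-in graph nodes; only the `op` and `target` attributes are used.
fake_nodes = [
    SimpleNamespace(op="placeholder", target="x"),
    SimpleNamespace(op="call_function", target=FakeTarget("unfold.default")),
    SimpleNamespace(op="call_function", target=FakeTarget("mul.Tensor")),
    SimpleNamespace(op="call_function", target=FakeTarget("fft_fft.default")),
    SimpleNamespace(op="output", target="output"),
]

missing = missing_call_function_names(fake_nodes, supported={"mul.Tensor"})
print(missing)
```

With the reproducer above, the same helper could be given `ep.graph.nodes` and whatever set of op names a given frontend build is known to handle, which lists the coverage gaps before calling from_exported_program.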

Triage

  • needs-triage
  • bug

tinywisdom, Oct 08 '25 10:10