[bug/feature][relax.frontend.torch] from_exported_program rejects randn.default (blocks repro that stresses advanced-indexing write + tuple output)
Actual behavior
Importing an exported PyTorch program fails early in the TVM Relax Torch frontend with:
AssertionError: Unsupported function types ['randn.default']
The model is intentionally small. It performs an advanced-indexing write with two index tensors, one per dimension (L[idx, idx] = v), and returns a tuple (y, L) to exercise relax::Tuple creation. However, the importer aborts first because it has no converter for aten::randn.default, so the later code paths are never reached.
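For readers unfamiliar with the pattern: the write `L[idx, idx] = v` pairs the two index tensors elementwise, so with `idx = arange(N)` it sets the diagonal of `L`. Below is a minimal pure-Python sketch of that behavior. It does not use PyTorch, the helper name `paired_index_write` is hypothetical, and it requires equal-length lists rather than emulating PyTorch's broadcasting.

```python
def paired_index_write(matrix, rows, cols, values):
    """Emulate matrix[rows, cols] = values for equal-length 1-D index lists."""
    if not (len(rows) == len(cols) == len(values)):
        raise ValueError("index and value lists must have the same length")
    # Each (row, col) pair selects exactly one element to overwrite.
    for r, c, v in zip(rows, cols, values):
        matrix[r][c] = v
    return matrix


N = 3
L = [[0.0] * N for _ in range(N)]
idx = list(range(N))
paired_index_write(L, idx, idx, [1.0, 2.0, 3.0])
print(L)  # [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]
```

Because both index lists are the same `idx`, only the diagonal elements are written. This is the code path the repro is meant to reach once `randn.default` is supported.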
Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
- PyTorch: 2.8.0
Steps to reproduce
# Purpose: Trigger importer on advanced indexing (two index tensors) + tuple outputs.
# Current blocker: Unsupported function types ['randn.default'].
import torch
import torch.nn as nn
import torch.nn.functional as F


def main():
    print("==== Versions ====")
    import tvm
    print("PyTorch:", torch.__version__)
    print("TVM :", tvm.__version__)
    torch.manual_seed(0)
    N = 5

    class M(nn.Module):
        def forward(self, x):
            L = torch.zeros(N, N, dtype=x.dtype, device=x.device)
            idx = torch.arange(N, device=x.device)
            v = torch.randn(N, device=x.device)  # <-- aten::randn.default
            v = F.elu(v) + 1.0 + 1e-8  # keep values > 0
            L[idx, idx] = v  # advanced indexing write
            y = x + 1
            return y, L  # tuple output -> relax::Tuple

    m = M().eval()
    ex_in = torch.randn(2, N)

    # 1) export OK
    from torch.export import export as torch_export
    ep = torch_export(m, (ex_in,))
    print("torch.export: OK")

    # 2) TVM import -> fails on unsupported randn
    from tvm.relax.frontend.torch import from_exported_program
    print("about to call from_exported_program(ep)...")
    mod = from_exported_program(ep)  # <-- aborts with Unsupported function types ['randn.default']
    print("from_exported_program: OK (if you see this, the randn support may have landed)")


if __name__ == "__main__":
    main()
Output
==== Versions ====
PyTorch: 2.8.0a0+gitba56102
TVM : 0.21.0
torch.export: OK
about to call from_exported_program(ep)...
Traceback (most recent call last):
  ...
  File ".../tvm/relax/frontend/torch/base_fx_graph_translator.py", line 116, in _check_unsupported_func_type
    assert not missing_func_types, f"Unsupported function types {missing_func_types}"
AssertionError: Unsupported function types ['randn.default']
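The assertion in the traceback suggests that the importer collects the names of the function targets in the graph and compares them against its table of converters before translating anything. Below is a simplified sketch of that kind of check, assuming this structure. `find_missing_func_types`, the op list, and the `supported` table are hypothetical stand-ins and are not TVM's actual code or its real coverage.

```python
def find_missing_func_types(graph_func_names, supported):
    """Return the sorted op names used by the graph that have no converter."""
    return sorted(set(graph_func_names) - set(supported))


# Hypothetical inputs: op names a graph might contain, and a toy converter table.
graph_func_names = ["zeros.default", "arange.default", "randn.default", "add.Tensor"]
supported = {"zeros.default": None, "arange.default": None, "add.Tensor": None}

missing = find_missing_func_types(graph_func_names, supported)
print(missing)  # ['randn.default']
```

Under this assumption, a single op with no converter is enough to reject the whole program up front, which is why the advanced-indexing and tuple-output paths are never exercised.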
Triage
- needs-triage
- bug