[bug][relax.frontend.torch] FFI segfault in tvm::relax::Tuple::Tuple when importing torch.export graph with 4D advanced-indexing write (aten.index_put_) and tuple outputs
Summary
Importing a torch.export-ed program into TVM Relax triggers a segmentation fault inside the FFI layer while a Relax Tuple is being constructed. The minimal model performs a 4D advanced-indexing write with two integer index tensors on the last two dimensions (L[..., idx, idx] = ...) and returns a Python tuple of tensors (x[..., :1], L). The exported graph contains no RNG ops (no randn), so the crash appears to be related to the combination of aten.index_put_ lowering and tuple-output construction.
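The write L[..., idx, idx] = diag sets the main diagonal of every trailing N×N matrix. As a possible workaround (not verified against TVM v0.21.0), the same result can be expressed without advanced indexing, using only standard PyTorch ops. The helper names below are hypothetical, and whether torch.where / torch.diag_embed avoid the crash is an untested assumption; the sketch only checks that both forms compute the same tensor.

```python
import torch


def write_diag_index_put(L: torch.Tensor, diag: torch.Tensor) -> torch.Tensor:
    """Reference: the advanced-indexing write used in the reproducer."""
    idx = torch.arange(L.shape[-1], device=L.device)
    out = L.clone()
    out[..., idx, idx] = diag  # the pattern the report associates with aten.index_put_
    return out


def write_diag_masked(L: torch.Tensor, diag: torch.Tensor) -> torch.Tensor:
    """Hypothetical workaround: same result without advanced indexing."""
    N = L.shape[-1]
    on_diag = torch.eye(N, device=L.device).bool()  # [N, N], broadcast over leading dims
    # diag_embed turns [..., N] into [..., N, N] with `diag` on the main diagonal.
    return torch.where(on_diag, torch.diag_embed(diag), L)


if __name__ == "__main__":
    L = torch.randn(2, 3, 5, 5)
    diag = torch.nn.functional.elu(torch.diagonal(L, dim1=-2, dim2=-1)) + 1.0
    assert torch.equal(write_diag_index_put(L, diag), write_diag_masked(L, diag))
```

In M4D.forward, the read side could likewise use torch.diagonal(L, dim1=-2, dim2=-1) instead of L[..., idx, idx], so that neither aten.index nor aten.index_put_ appears in the exported graph.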
Actual behavior
[1] torch.export ...
=== Exported ops ===
... (as above)
[2] tvm.relax.frontend.torch.from_exported_program ...
!!!!!!! TVM FFI encountered a Segfault !!!!!!!
...
tvm::relax::Tuple::Tuple(tvm::ffi::Array<tvm::RelaxExpr, void>, tvm::Span) [clone .cold]
...
Segmentation fault (core dumped)
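Because the crash takes down the whole interpreter, it can help during triage to run each conversion attempt in a child process and inspect its exit status. A minimal sketch, assuming a POSIX system where a child killed by signal N reports return code -N (so -11 for SIGSEGV); run_isolated is a hypothetical helper name.

```python
import signal
import subprocess
import sys


def run_isolated(source: str) -> str:
    """Run Python `source` in a fresh interpreter and classify its exit.

    A segfault in native code kills only the child process, so the calling
    script survives and can report it. Assumes POSIX semantics, where a child
    killed by signal N has returncode -N.
    """
    proc = subprocess.run([sys.executable, "-c", source], capture_output=True, text=True)
    if proc.returncode == -signal.SIGSEGV:
        return "segfault"
    if proc.returncode == 0:
        return "ok"
    return f"error ({proc.returncode})"
```

For example, passing the reproducer's source text to run_isolated would be expected to return "segfault" on the environment listed below, while the calling script keeps running.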
Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
- PyTorch: 2.8.0
Steps to reproduce
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""  # avoid GPU warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.export import export as torch_export
from tvm.relax.frontend.torch import from_exported_program


class M4D(nn.Module):
    def forward(self, x):
        B, K, N = 2, 3, 5
        L = x.new_zeros(B, K, N, N)  # tensor construct only; no randomness
        idx = torch.arange(N, device=x.device)
        # key trigger: gather diagonal, apply smooth monotonic transform, scatter back
        diag = L[..., idx, idx]  # shape: [B, K, N]
        diag = F.elu(diag) + 1.0 + 1e-8  # avoid all-zero; any smooth transform works
        L[..., idx, idx] = diag  # advanced indexing write (two int index tensors)
        # key trigger: return a Python-level tuple (two tensors)
        return x[..., :1], L


if __name__ == "__main__":
    torch.manual_seed(0)
    m = M4D().eval()
    ex_in = torch.zeros(2, 3, 5)  # any input; ensures no randn exported
    print("[1] torch.export ...")
    ep = torch_export(m, (ex_in,))
    # sanity: list exported ops
    try:
        print("=== Exported ops ===")
        for n in ep.graph.nodes:
            print(getattr(n, "op", None), getattr(n, "target", None))
    except Exception:
        pass
    print("[2] tvm.relax.frontend.torch.from_exported_program ...")
    mod = from_exported_program(ep)  # <-- segfaults inside FFI Tuple construction
    print("[OK] Converted without segfault (if you see this, env may differ)")
Triage
- needs-triage
- bug