[Bug] [relax][torch] from_exported_program segfault with exported MHA using eq(0)/expand mask + in-place masked_fill_ (get_attr lifting warning from PyTorch)
Expected behavior
tvm.relax.frontend.torch.from_exported_program(ep) should import this tiny Transformer-like block (exported via torch.export) and return a Relax IRModule, or at worst fail with a readable Python-level error rather than a segfault.
Actual behavior
torch.export succeeds, but PyTorch warns that it inserted a get_attr node without a backing submodule/parameter/buffer (the get_attr lifting warning).
Immediately afterwards, tvm.relax.frontend.torch.from_exported_program(ep) segfaults in TVM's FFI layer while translating the exported program; in my run the backtrace ends in tvm::relax::Tuple::Tuple(...):
!!!!!!! TVM FFI encountered a Segfault !!!!!!!
... tvm::relax::Tuple::Tuple(...) ...
Segmentation fault (core dumped)
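For triage, the lifted get_attr nodes can be confirmed directly on the exported program before handing it to TVM. This is my own diagnostic snippet, not part of the original repro; ep is the ExportedProgram produced by the script under "Steps to reproduce":

# Hedged diagnostic sketch: list the get_attr nodes torch.export lifted into the
# graph and whether a backing attribute actually exists on the GraphModule.
for node in ep.graph_module.graph.nodes:
    if node.op == "get_attr":
        has_backing = hasattr(ep.graph_module, str(node.target))
        print("lifted get_attr:", node.target, "| backing attr present:", has_backing)

Running the repro with python -X faulthandler additionally dumps the Python-level frames when the SIGSEGV fires, which should point at the from_exported_program call.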
Environment
- OS: Ubuntu 22.04.4 LTS (x86_64)
- TVM version: release v0.21.0
- Python: 3.10.16
- LLVM: 17.0.6
- PyTorch: 2.7.1
Steps to reproduce
# mini_repro_export_tvm_segfault.py
import math

import torch
import torch.nn as nn


def get_attn_pad_mask(seq_q, seq_k):
    B, Lq = seq_q.size()
    B2, Lk = seq_k.size()
    assert B == B2
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # (B,1,Lk)
    return pad_attn_mask.expand(B, Lq, Lk)         # (B,Lq,Lk)


class TinyMHA(nn.Module):
    def __init__(self, d_model=64, d_k=16, n_heads=4, dropout=0.1):
        super().__init__()
        self.h, self.dk = n_heads, d_k
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.proj = nn.Linear(d_k * n_heads, d_model, bias=False)
        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, attn_mask):  # x: (B,L,dm), attn_mask: (B,L,L)
        B, L, _ = x.shape
        q = self.W_Q(x).view(B, L, self.h, self.dk).transpose(1, 2)  # (B,H,L,dk)
        k = self.W_K(x).view(B, L, self.h, self.dk).transpose(1, 2)
        v = self.W_V(x).view(B, L, self.h, self.dk).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dk)  # (B,H,L,L)
        # In-place masked_fill_ with broadcasted mask coming from eq(0)+expand
        scores.masked_fill_(attn_mask.unsqueeze(1), -1e9)
        attn = torch.softmax(scores, dim=-1)
        ctx = torch.matmul(attn, v).transpose(1, 2).reshape(B, L, self.h * self.dk)
        out = self.drop(self.proj(ctx))
        return self.ln(out + x)


class MiniModel(nn.Module):
    def __init__(self, vocab=10000, d_model=64):
        super().__init__()
        self.emb = nn.Embedding(vocab, d_model)
        self.mha = TinyMHA(d_model=d_model, d_k=16, n_heads=4, dropout=0.1)
        self.proj = nn.Linear(d_model, vocab, bias=False)

    def forward(self, enc_inputs, dec_inputs_unused=None):
        x = self.emb(enc_inputs)                          # (B,L,dm)
        mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # (B,L,L)
        y = self.mha(x, mask)                             # (B,L,dm)
        logits = self.proj(y)                             # (B,L,V)
        return logits.reshape(-1, logits.size(-1))        # (B*L, V)


def my_model_function():
    return MiniModel()


def GetInput():
    enc = torch.randint(0, 10000, (2, 5))
    enc[0, 0] = 0  # ensure eq(0) path is taken
    dec = torch.randint(0, 10000, (2, 5))
    return (enc, dec)


import numpy as np
from torch.export import export as torch_export

from tvm.relax.frontend.torch import from_exported_program


def trigger_known_bugs(model=None):
    if model is None:
        model = my_model_function()
    torch.manual_seed(42)
    np.random.seed(42)
    model.eval()
    args = GetInput()
    ep = torch_export(model, args)   # Emits get_attr warnings (see below)
    mod = from_exported_program(ep)  # <-- TVM segfaults here in my env
    print(mod)


if __name__ == "__main__":
    import os
    os.environ.setdefault("CUDA_VISIBLE_DEVICES", "6,7")
    trigger_known_bugs()
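To help narrow the trigger, here is a hedged isolation sketch (my own, not part of the original script; the MaskOnly module name is made up). It keeps only the eq(0)/expand mask plus the in-place masked_fill_ on an intermediate tensor and drops the rest of the MHA. If this variant imports cleanly, the problem is more likely in how the importer handles the broadcast mask / in-place op combination than in the attention math itself.

# Hedged isolation sketch: only the eq(0)/expand mask + in-place masked_fill_
# pattern from the repro, without linear layers, LayerNorm, or dropout.
import torch
import torch.nn as nn
from torch.export import export as torch_export
from tvm.relax.frontend.torch import from_exported_program

class MaskOnly(nn.Module):  # hypothetical name, not from the repro above
    def forward(self, scores, seq):  # scores: (B,L,L) float, seq: (B,L) int
        mask = seq.eq(0).unsqueeze(1).expand_as(scores)  # (B,L,L) bool
        scores = scores.clone()
        scores.masked_fill_(mask, -1e9)  # same in-place op as in TinyMHA
        return torch.softmax(scores, dim=-1)

scores = torch.randn(2, 5, 5)
seq = torch.randint(0, 10, (2, 5))
seq[0, 0] = 0  # keep the eq(0) path active
ep = torch_export(MaskOnly(), (scores, seq))
print(from_exported_program(ep))  # check whether this still segfaults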
Triage
- needs-triage
- bug
cc @junrushao @shingjan