Repository: tvm

[Bug] [relax][torch] from_exported_program segfault with exported MHA using eq(0)/expand mask + in-place masked_fill_ (get_attr lifting warning from PyTorch)

Status: Open · opened by tinywisdom 2 months ago · 0 comments

Expected behavior

`tvm.relax.frontend.torch.from_exported_program(ep)` should either translate the exported program into a Relax module or fail with a Python exception; it should never crash the process. Instead, a tiny Transformer-like block exported via `torch.export` crashes TVM during the import. Before the crash, PyTorch warns that `torch.export` inserted a `get_attr` node without a backing submodule/parameter/buffer. TVM then segfaults in `tvm::relax::Tuple::Tuple(...)` (FFI path) while translating the exported program.

Actual behavior

torch.export succeeds but prints the above get_attr lifting warnings.

Immediately after, tvm.relax.frontend.torch.from_exported_program(ep) triggers an FFI segfault. (In my run it shows an FFI backtrace ending in tvm::relax::Tuple::Tuple(...) / TVM FFI traceback.)

!!!!!!! TVM FFI encountered a Segfault !!!!!!! 
... tvm::relax::Tuple::Tuple(...) ...
Segmentation fault (core dumped)

Environment

  • OS: Ubuntu 22.04.4 LTS (x86_64)
  • TVM version: release v0.21.0
  • Python: 3.10.16
  • LLVM: 17.0.6
  • PyTorch: 2.7.1

Steps to reproduce

# mini_repro_export_tvm_segfault.py
import math
import torch
import torch.nn as nn

def get_attn_pad_mask(seq_q, seq_k):
    """Build a boolean key-padding mask for attention.

    Args:
        seq_q: 2-D tensor of query token ids, shape (B, Lq).
        seq_k: 2-D tensor of key token ids, shape (B, Lk). Positions whose
            id equals 0 are treated as padding.

    Returns:
        Boolean tensor of shape (B, Lq, Lk) that is True wherever the key
        position is padding. It is an ``expand`` view of a (B, 1, Lk)
        tensor, not a materialized copy.

    Raises:
        ValueError: if ``seq_q`` and ``seq_k`` have different batch sizes.
    """
    batch_q, len_q = seq_q.size()
    batch_k, len_k = seq_k.size()
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if batch_q != batch_k:
        raise ValueError(
            f"batch size mismatch: seq_q has {batch_q}, seq_k has {batch_k}"
        )
    # `.data` is not needed for a non-differentiable eq(0) mask; it is kept
    # so the exported graph matches the reported repro.
    # NOTE(review): possibly related to the get_attr lifting warning from
    # torch.export — unconfirmed.
    pad_attn_mask = seq_k.data.eq(0).unsqueeze(1)  # (B, 1, Lk)
    return pad_attn_mask.expand(batch_q, len_q, len_k)  # (B, Lq, Lk)

class TinyMHA(nn.Module):
    """Single multi-head self-attention block with residual + LayerNorm.

    Queries, keys and values are all projected from the same input ``x``.
    Masked score positions are overwritten in place with ``-1e9`` before the
    softmax. The in-place ``masked_fill_`` with a broadcast mask is part of
    the pattern reported to crash the TVM importer, so it is kept.

    Args:
        d_model: width of the input and output features.
        d_k: per-head width of Q/K/V.
        n_heads: number of attention heads.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, d_model=64, d_k=16, n_heads=4, dropout=0.1):
        super().__init__()
        self.h = n_heads
        self.dk = d_k
        inner_dim = n_heads * d_k
        # Construction order (Q, K, V, proj) is kept so that, under a fixed
        # seed, parameters are initialised exactly as in the original repro.
        self.W_Q = nn.Linear(d_model, inner_dim, bias=False)
        self.W_K = nn.Linear(d_model, inner_dim, bias=False)
        self.W_V = nn.Linear(d_model, inner_dim, bias=False)
        self.proj = nn.Linear(inner_dim, d_model, bias=False)
        self.ln = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def _split_heads(self, layer, x, batch, length):
        """Apply ``layer`` to ``x`` and reshape (B, L, H*dk) -> (B, H, L, dk)."""
        projected = layer(x)
        return projected.view(batch, length, self.h, self.dk).transpose(1, 2)

    def forward(self, x, attn_mask):
        """Attend over ``x`` and return ``LayerNorm(dropout(proj(ctx)) + x)``.

        Args:
            x: (B, L, d_model) input features.
            attn_mask: (B, L, L) boolean mask. True marks score positions
                that are replaced by -1e9; the mask is broadcast over heads.

        Returns:
            Tensor of shape (B, L, d_model).
        """
        batch, length, _ = x.shape
        query = self._split_heads(self.W_Q, x, batch, length)
        key = self._split_heads(self.W_K, x, batch, length)
        value = self._split_heads(self.W_V, x, batch, length)

        scale = math.sqrt(self.dk)
        scores = torch.matmul(query, key.transpose(-1, -2)) / scale  # (B,H,L,L)
        scores.masked_fill_(attn_mask.unsqueeze(1), -1e9)

        weights = torch.softmax(scores, dim=-1)
        context = torch.matmul(weights, value).transpose(1, 2)
        context = context.reshape(batch, length, self.h * self.dk)
        return self.ln(self.drop(self.proj(context)) + x)

class MiniModel(nn.Module):
    """Embedding -> one TinyMHA block -> vocabulary projection.

    Args:
        vocab: vocabulary size; used for both the embedding table and the
            output projection.
        d_model: model width.
    """

    def __init__(self, vocab=10000, d_model=64):
        super().__init__()
        # Submodule order is kept so seeded initialisation matches the repro.
        self.emb = nn.Embedding(vocab, d_model)
        self.mha = TinyMHA(d_model=d_model, d_k=16, n_heads=4, dropout=0.1)
        self.proj = nn.Linear(d_model, vocab, bias=False)

    def forward(self, enc_inputs, dec_inputs_unused=None):
        """Return logits flattened to shape (B*L, vocab).

        ``dec_inputs_unused`` is never read. It exists so the model can be
        called with the (enc, dec) pair returned by ``GetInput()``.
        """
        embedded = self.emb(enc_inputs)                          # (B, L, dm)
        pad_mask = get_attn_pad_mask(enc_inputs, enc_inputs)     # (B, L, L)
        hidden = self.mha(embedded, pad_mask)                    # (B, L, dm)
        logits = self.proj(hidden)                               # (B, L, V)
        vocab_size = logits.size(-1)
        return logits.reshape(-1, vocab_size)                    # (B*L, V)

def my_model_function():
    """Return a freshly constructed MiniModel with default hyper-parameters."""
    model = MiniModel()
    return model
def GetInput():
    """Build the (enc, dec) example inputs used for export.

    Both are integer tensors of shape (2, 5) with ids drawn from
    [0, 10000). ``enc[0, 0]`` is forced to 0 so that
    ``get_attn_pad_mask`` yields at least one True entry.
    """
    shape = (2, 5)
    enc = torch.randint(0, 10000, shape)
    # Force a padding token so the eq(0) mask path is exercised.
    enc[0, 0] = 0
    dec = torch.randint(0, 10000, shape)
    return enc, dec

import numpy as np
from torch.export import export as torch_export
from tvm.relax.frontend.torch import from_exported_program

def trigger_known_bugs(model=None):
    """Export the model with torch.export and import it into TVM Relax.

    torch and numpy are seeded *before* the model is built. This makes both
    the default model's weights and the example inputs reproducible. The
    original seeded only after construction, so the weights varied run to
    run.

    Args:
        model: optional ``nn.Module`` to export. Defaults to
            ``my_model_function()``.
    """
    torch.manual_seed(42)
    np.random.seed(42)
    if model is None:
        model = my_model_function()
    model.eval()
    args = GetInput()

    # torch.export emits the get_attr lifting warnings described in the report.
    ep = torch_export(model, args)
    # Reported crash site: TVM FFI segfault in tvm::relax::Tuple::Tuple(...).
    mod = from_exported_program(ep)
    print(mod)

if __name__ == "__main__":
    import os

    # GPU pinning carried over from the reporter's machine. Nothing in this
    # script moves tensors to CUDA, so the repro runs on CPU either way.
    if "CUDA_VISIBLE_DEVICES" not in os.environ:
        os.environ["CUDA_VISIBLE_DEVICES"] = "6,7"
    trigger_known_bugs()

Triage

  • needs-triage
  • bug

cc @junrushao @shingjan

— tinywisdom, Oct 30 '25 07:10