triton error while running Mamba2 with slow path

Open Seeker98 opened this issue 8 months ago • 10 comments

Following #355, I added "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" to the "mamba_chunk_scan_combined" function in "ssd_combined.py", and the run failed with this error:

Unsupported: autograd.Function with body that accepts non-Tensors as input. Got: <class 'tuple'>

from user code:
   File "/home/hit/.conda/envs/torch2/lib/python3.9/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 560, in mamba_chunk_scan_combined
    return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
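
For anyone reproducing this, the extra logging mentioned in the error output above can also be turned on from inside the script rather than from the shell. This is only a minimal sketch: it assumes the variables need to be set before torch is imported in order to take effect, which I have not verified for every version.

```python
import os

# Values taken from the hint in the error message above.
# Assumption: they must be in the environment before torch is imported.
os.environ["TORCH_LOGS"] = "+dynamo"
os.environ["TORCHDYNAMO_VERBOSE"] = "1"

# Import torch (and mamba_ssm) only after this point, e.g.:
# import torch
# from mamba_ssm import Mamba2
```

Setting the same two variables on the command line before launching Python should be equivalent.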

Code to reproduce:

import torch
from mamba_ssm import Mamba2
batch, length, dim = 8, 1024, 128
x = torch.randn(batch, length, dim, device="cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
    headdim=32,
    use_mem_eff_path=False,  # take the slow path, which calls mamba_chunk_scan_combined
).to("cuda")
y = model(x)
assert y.shape == x.shape
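
To give a sense of how small this test model is, the rough estimate from the comment in the snippet above (3 * expand * d_model^2) can be worked out for these settings. The helper name below is hypothetical, and the formula is only the approximation quoted in that comment, not an exact parameter count.

```python
def approx_mamba2_params(d_model: int, expand: int) -> int:
    """Rough parameter estimate, using the 3 * expand * d_model^2 rule of thumb."""
    return 3 * expand * d_model ** 2


# Settings from the reproduction script: d_model=128, expand=2.
print(approx_mamba2_params(d_model=128, expand=2))  # 98304
```

So the module in the reproduction has on the order of 100k parameters.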

I'm not sure what else to provide, but my package versions are: mamba-ssm 2.0.3, causal-conv1d 1.2.2.post1, and pytorch 2.3.1 (build py39_cu121_cudnn8.9.2_0).

Seeker98, Jun 06 '24 13:06