Triton error while running Mamba2 with the slow path
As in #355, I added the decorator @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) to the mamba_chunk_scan_combined function in ssd_combined.py, and the run failed with the following error:
Unsupported: autograd.Function with body that accepts non-Tensors as input. Got: <class 'tuple'>
from user code:
File "/home/hit/.conda/envs/torch2/lib/python3.9/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 560, in mamba_chunk_scan_combined
return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
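In case it helps narrow this down: the only tuple I can see among the arguments in that apply call is dt_limit. Below is a minimal sketch of how I checked which arguments are Python containers. container_args is a hypothetical helper of mine, not part of mamba_ssm or PyTorch, and the values are stand-ins: I am assuming dt_limit defaults to (0.0, float("inf")) and chunk_size to 256, which may not match every install.

```python
# Hypothetical helper (not part of mamba_ssm or PyTorch): report which
# keyword arguments are Python containers rather than tensors or scalars.
def container_args(**kwargs):
    """Return {name: type name} for every tuple, list, or dict argument."""
    return {
        name: type(value).__name__
        for name, value in kwargs.items()
        if isinstance(value, (tuple, list, dict))
    }


# Stand-in values for the non-tensor arguments of the failing call.
# ASSUMPTION: dt_limit=(0.0, inf) and chunk_size=256 are the defaults.
suspects = container_args(
    chunk_size=256,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
    return_final_states=False,
)
print(suspects)  # {'dt_limit': 'tuple'}
```

If that assumption holds, the tuple in the error message is probably dt_limit, but I have not confirmed this against the library source.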
Code to reproduce:
import torch
from mamba_ssm import Mamba2
batch, length, dim = 8, 1024, 128
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim,  # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
    headdim=32,
    use_mem_eff_path=False,  # Take the slow (non-fused) path
).to("cuda")
y = model(x)
assert y.shape == x.shape
I'm not sure what else to provide, but my packages are: mamba-ssm 2.0.3, causal-conv1d 1.2.2.post1, and pytorch 2.3.1 (build py39_cu121_cudnn8.9.2_0).