
triton error while running Mamba2 with slow path

Open Seeker98 opened this issue 1 year ago • 10 comments

As in #355, I added "@torch.compile(options={"triton.cudagraphs": True}, fullgraph=True)" to the "mamba_chunk_scan_combined" function in "ssd_combined.py", and running it failed with this error:

Unsupported: autograd.Function with body that accepts non-Tensors as input. Got: <class 'tuple'>

from user code:
   File "/home/hit/.conda/envs/torch2/lib/python3.9/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 560, in mamba_chunk_scan_combined
    return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
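
For context, this looks like a TorchDynamo limitation rather than a Triton bug: mamba_chunk_scan_combined forwards non-Tensor arguments into MambaChunkScanCombinedFn.apply, and judging by the "Got: <class 'tuple'>" message the offender is likely the tuple-valued dt_limit argument. A minimal sketch that reproduces a similar failure outside mamba_ssm (TupleArgFn is a made-up stand-in; newer PyTorch releases may trace this fine):

import torch

class TupleArgFn(torch.autograd.Function):
    # Hypothetical stand-in for MambaChunkScanCombinedFn: the second
    # argument is a tuple, mirroring dt_limit in mamba_chunk_scan_combined.
    @staticmethod
    def forward(ctx, x, limit):
        lo, hi = limit
        return x.clamp(lo, hi)

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None

@torch.compile(fullgraph=True)
def fn(x):
    return TupleArgFn.apply(x, (0.0, 1.0))

# On PyTorch 2.3.x this raises a similar "Unsupported: autograd.Function
# with body that accepts non-Tensors as input. Got: <class 'tuple'>".
fn(torch.randn(8, requires_grad=True))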

Code to reproduce:

import torch
from mamba_ssm import Mamba2
batch, length, dim = 8, 1024, 128
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
    headdim=32,
    use_mem_eff_path=False
).to("cuda")
y = model(x)
assert y.shape == x.shape

I'm not sure what else to provide, but my packages are: mamba-ssm 2.0.3, causal-conv1d 1.2.2.post1, pytorch 2.3.1 (py39_cu121_cudnn8.9.2_0).

Seeker98 avatar Jun 06 '24 13:06 Seeker98

I tried the following and the time doesn't seem to change. Maybe it's just a one-time initial delay:


# warm-up loop: repeated forward passes to rule out a one-time compilation cost
for i in range(10):
    x = torch.randn(batch, length, dim).to("cuda")
    y = model(x)
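
One caveat on timing this: CUDA kernels launch asynchronously, so wall-clock measurements need explicit synchronization, and the first compiled call pays a one-time compilation cost. A minimal sketch (time_forward is a hypothetical helper, not part of mamba_ssm):

import time
import torch

# Warm up first (absorbing one-time compilation cost), then synchronize
# around the timed region, since unsynchronized wall-clock timings
# undercount GPU work that is still in flight.
def time_forward(model, x, iters=10, warmup=3):
    for _ in range(warmup):
        model(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        model(x)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters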

arelkeselbri avatar Jun 06 '24 14:06 arelkeselbri

Well, I'm wondering why adding compile as in the #355 discussion makes the code fail, since the author mentioned it could give a big speedup.
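
Until Dynamo supports this, two possible fallbacks, both hinted at by the error output itself (a sketch, not verified against this exact setup, and either may give up much of the speedup reported in #355):

import torch
import torch._dynamo
from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined

# Option 1: compile without fullgraph=True, so Dynamo can insert a graph
# break at MambaChunkScanCombinedFn.apply instead of raising Unsupported.
compiled_scan = torch.compile(mamba_chunk_scan_combined)

# Option 2: keep the decorator from #355 as-is, but fall back to eager on
# unsupported code, exactly as the error message suggests.
torch._dynamo.config.suppress_errors = True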

Seeker98 avatar Jun 06 '24 14:06 Seeker98

the same issue

Baijiong-Lin avatar Jun 07 '24 08:06 Baijiong-Lin

Same issue here.

yaosi-ym avatar Jun 10 '24 07:06 yaosi-ym

the same issue

zizheng-guo avatar Jun 14 '24 06:06 zizheng-guo

Hi, I have the same problem, have you solved it?

c-junhao avatar Jun 16 '24 12:06 c-junhao

the same issue

TimothyChen225 avatar Jun 18 '24 05:06 TimothyChen225

Same here

JulienSiems avatar Jul 13 '24 14:07 JulienSiems

Same here...

SH-Yoon-01 avatar Jul 30 '24 19:07 SH-Yoon-01

same issue here

florinshen avatar Sep 05 '24 07:09 florinshen

Same issue here. Has anyone solved it? Thanks!

yingyingSun-01 avatar Oct 30 '24 03:10 yingyingSun-01