Mamba-2: IndexError: map::at

Open · Prophet-Kathleen opened this issue 1 year ago · 18 comments

Hi, here is the code:

import torch
from mamba_ssm import Mamba2

batch, length, dim = 2, 64, 256
x = torch.randn(batch, length, dim).to("cuda")

model = Mamba2(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=64,  # SSM state expansion factor, typically 64 or 128
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

y = model(x)

print(x.shape)
print(y.shape)

And here is the error:

Traceback (most recent call last):
  File "/samba/network-storage/toshiba/netcode/mynet/network/Mamba_Based/Mamba2.py", line 15, in <module>
    y = model(x)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py", line 176, in forward
    out = mamba_split_conv1d_scan_combined(
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 908, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 115, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 773, in forward
    out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 308, in _mamba_chunk_scan_combined_fwd
    states = _chunk_state_fwd(B, x, dt, dA_cumsum, seq_idx=seq_idx, states_in_fp32=True)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 653, in _chunk_state_fwd
    _chunk_state_fwd_kernel[grid](
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/runtime/jit.py", line 167, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 122, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/testing.py", line 102, in do_bench
    fn()
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 110, in kernel_call
    self.fn.run(
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/runtime/jit.py", line 416, in run
    self.cache[device][key] = compile(
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/compiler/compiler.py", line 193, in compile
    next_module = compile_ir(module, metadata)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/compiler/backends/cuda.py", line 199, in <lambda>
    stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
  File "/samba/network-storage/toshiba/cache/envs/mamba-env/lib/python3.10/site-packages/triton/compiler/backends/cuda.py", line 173, in make_llir
    ret = translate_triton_gpu_to_llvmir(src, capability, tma_infos, runtime.TARGET.NVVM)
IndexError: map::at

ChatGPT (GPT-4o) says that something is wrong with either the triton package or mamba_split_conv1d_scan_combined:

The error message indicates that an IndexError: map::at occurred during the forward pass of the Mamba2 model. This error typically happens when code tries to access a non-existent key in a map or dictionary.

From the traceback, we see that the error originates from within the mamba_split_conv1d_scan_combined function, specifically within the _chunk_state_fwd function when calling _chunk_state_fwd_kernel.

Possible causes:
1. Incorrect parameter passing: the parameters passed to mamba_split_conv1d_scan_combined or _chunk_state_fwd might be incorrect, leading to errors when accessing data within the function.
2. CUDA compatibility issues: there might be compatibility issues with the CUDA version or the hardware, causing errors when the Triton compiler compiles the kernel.
3. Function implementation issues: there might be issues in the implementation of _chunk_state_fwd or related functions, particularly when accessing elements in a map.

Solutions:
1. Check parameter passing: ensure that all parameters passed to mamba_split_conv1d_scan_combined have the correct shape and type.
2. Update CUDA and Triton: ensure that the CUDA version and the Triton compiler are up to date and compatible with your hardware.
3. Debug the function implementation: add debugging information within _chunk_state_fwd to identify which key causes the map::at error.

And here is my conda environment:

# packages in environment at /samba/network-storage/toshiba/cache/envs/mamba-env:
#
# Name                    Version                   Build  Channel
_libgcc_mutex             0.1                        main    defaults
_openmp_mutex             5.1                       1_gnu    defaults
_sysroot_linux-64_curr_repodata_hack 3                   haa98f57_10    defaults
binutils_impl_linux-64    2.38                 h2a08ee3_1    defaults
binutils_linux-64         2.38.0               hc2dff05_0    defaults
bzip2                     1.0.8                h5eee18b_6    defaults
ca-certificates           2024.3.11            h06a4308_0    defaults
causal-conv1d             1.2.2.post1              pypi_0    pypi
certifi                   2024.6.2                 pypi_0    pypi
charset-normalizer        3.3.2                    pypi_0    pypi
cuda-cccl_linux-64        12.5.39                       0    nvidia
cuda-crt-dev_linux-64     12.5.40                       0    nvidia
cuda-crt-tools            12.5.40                       0    nvidia
cuda-cudart               12.5.39                       0    nvidia
cuda-cudart-dev           12.5.39                       0    nvidia
cuda-cudart-dev_linux-64  12.5.39                       0    nvidia
cuda-cudart-static        12.5.39                       0    nvidia
cuda-cudart-static_linux-64 12.5.39                       0    nvidia
cuda-cudart_linux-64      12.5.39                       0    nvidia
cuda-driver-dev_linux-64  12.5.39                       0    nvidia
cuda-nvcc                 12.5.40                       0    nvidia
cuda-nvcc-dev_linux-64    12.5.40                       0    nvidia
cuda-nvcc-impl            12.5.40                       0    nvidia
cuda-nvcc-tools           12.5.40                       0    nvidia
cuda-nvcc_linux-64        12.5.40                       0    nvidia
cuda-nvvm-dev_linux-64    12.5.40                       0    nvidia
cuda-nvvm-impl            12.5.40                       0    nvidia
cuda-nvvm-tools           12.5.40                       0    nvidia
cuda-version              12.5                          3    nvidia
einops                    0.8.0                    pypi_0    pypi
filelock                  3.14.0                   pypi_0    pypi
fsspec                    2024.6.0                 pypi_0    pypi
gcc_impl_linux-64         11.2.0               h1234567_1    defaults
gcc_linux-64              11.2.0               h5c386dc_0    defaults
gxx_impl_linux-64         11.2.0               h1234567_1    defaults
gxx_linux-64              11.2.0               hc2dff05_0    defaults
huggingface-hub           0.23.2                   pypi_0    pypi
idna                      3.7                      pypi_0    pypi
jinja2                    3.1.4                    pypi_0    pypi
kernel-headers_linux-64   3.10.0              h57e8cba_10    defaults
ld_impl_linux-64          2.38                 h1181459_1    defaults
libffi                    3.4.4                h6a678d5_1    defaults
libgcc-devel_linux-64     11.2.0               h1234567_1    defaults
libgcc-ng                 11.2.0               h1234567_1    defaults
libgomp                   11.2.0               h1234567_1    defaults
libstdcxx-devel_linux-64  11.2.0               h1234567_1    defaults
libstdcxx-ng              11.2.0               h1234567_1    defaults
libuuid                   1.41.5               h5eee18b_0    defaults
mamba-ssm                 2.0.3                    pypi_0    pypi
markupsafe                2.1.5                    pypi_0    pypi
mpmath                    1.3.0                    pypi_0    pypi
ncurses                   6.4                  h6a678d5_0    defaults
networkx                  3.3                      pypi_0    pypi
ninja                     1.11.1.1                 pypi_0    pypi
numpy                     1.26.4                   pypi_0    pypi
nvidia-cublas-cu12        12.1.3.1                 pypi_0    pypi
nvidia-cuda-cupti-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-nvrtc-cu12    12.1.105                 pypi_0    pypi
nvidia-cuda-runtime-cu12  12.1.105                 pypi_0    pypi
nvidia-cudnn-cu12         8.9.2.26                 pypi_0    pypi
nvidia-cufft-cu12         11.0.2.54                pypi_0    pypi
nvidia-curand-cu12        10.3.2.106               pypi_0    pypi
nvidia-cusolver-cu12      11.4.5.107               pypi_0    pypi
nvidia-cusparse-cu12      12.1.0.106               pypi_0    pypi
nvidia-nccl-cu12          2.20.5                   pypi_0    pypi
nvidia-nvjitlink-cu12     12.5.40                  pypi_0    pypi
nvidia-nvtx-cu12          12.1.105                 pypi_0    pypi
openssl                   3.0.13               h7f8727e_2    defaults
packaging                 24.0                     pypi_0    pypi
pillow                    10.3.0                   pypi_0    pypi
pip                       24.0            py310h06a4308_0    defaults
python                    3.10.14              h955ad1f_1    defaults
pyyaml                    6.0.1                    pypi_0    pypi
readline                  8.2                  h5eee18b_0    defaults
regex                     2024.5.15                pypi_0    pypi
requests                  2.32.3                   pypi_0    pypi
safetensors               0.4.3                    pypi_0    pypi
setuptools                69.5.1          py310h06a4308_0    defaults
sqlite                    3.45.3               h5eee18b_0    defaults
sympy                     1.12.1                   pypi_0    pypi
sysroot_linux-64          2.17                h57e8cba_10    defaults
timm                      1.0.3                    pypi_0    pypi
tk                        8.6.14               h39e8969_0    defaults
tokenizers                0.19.1                   pypi_0    pypi
torch                     2.3.0                    pypi_0    pypi
torchaudio                2.3.0                    pypi_0    pypi
torchvision               0.18.0                   pypi_0    pypi
tqdm                      4.66.4                   pypi_0    pypi
transformers              4.41.2                   pypi_0    pypi
triton                    2.3.0                    pypi_0    pypi
typing-extensions         4.12.1                   pypi_0    pypi
tzdata                    2024a                h04d1e81_0    defaults
urllib3                   2.2.1                    pypi_0    pypi
wheel                     0.43.0          py310h06a4308_0    defaults
xz                        5.4.6                h5eee18b_1    defaults
zlib                      1.2.13               h5eee18b_1    defaults

I just want to try the new Mamba2 as a demo. How do I fix this problem?

Prophet-Kathleen, Jun 05 '24 05:06

Looks like a Triton error. Which GPU do you use?

tridao, Jun 05 '24 05:06

An RTX 2080 Ti with driver version 535.171.04. Right now I am trying a different Python version, the newest Triton (2.3.1), or a different PyTorch version to fix this problem.

Prophet-Kathleen, Jun 05 '24 06:06

I'm not sure Triton supports GPUs older than Ampere (e.g. the 2080) very well.

tridao, Jun 05 '24 06:06

I just borrowed an RTX 3060 (driver version 535.171.04) to test the code, and it works with Triton 2.3.1. Thanks!

Prophet-Kathleen, Jun 05 '24 10:06

Same issue here, but on a V100 :(

ghaddarAbs, Jun 06 '24 01:06

With the newest Triton version (2.3.1), this seems to be mainly related to the GPU in use. I also ran into this error on an RTX 2080 Ti, so I tried to reproduce it on the different GPUs I have available.

Setup: Nvidia driver 535.161.07, CUDA 11.8, Triton 2.3.1, mamba-ssm v2.0.3
Working GPUs: V100, RTX 3090, RTX 4090, A100 (40GB & 80GB)
IndexError (map::at): RTX 2080 Ti, Titan RTX, Quadro RTX 6000
So the error seems to occur only on the Turing microarchitecture.

@ghaddarAbs V100 works for me; maybe update triton to 2.3.1?
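A minimal sketch for checking whether your own GPU is in the affected group. It assumes that "Turing" means CUDA compute capability 7.5, which matches every failing card reported in this thread; `is_turing` and `main` are hypothetical names, while `torch.cuda.get_device_capability()` and `torch.cuda.get_device_name()` are standard PyTorch calls.

```python
"""Check whether the visible GPU is a Turing card (compute capability 7.5)."""

# Assumption: the failing GPUs in this thread (RTX 2080 Ti, Titan RTX,
# Quadro RTX 6000, Tesla T4, GTX 1660 Ti) all report compute capability 7.5.
TURING_CAPABILITY = (7, 5)


def is_turing(capability):
    """Return True if a (major, minor) compute capability is Turing (sm_75)."""
    return tuple(capability) == TURING_CAPABILITY


def main():
    try:
        import torch  # only needed to query a real device
    except ImportError:
        print("PyTorch is not installed; nothing to check.")
        return
    if not torch.cuda.is_available():
        print("No CUDA device is visible.")
        return
    major, minor = torch.cuda.get_device_capability()
    name = torch.cuda.get_device_name()
    if is_turing((major, minor)):
        print(f"{name} is sm_{major}{minor} (Turing): Mamba2 may hit IndexError: map::at")
    else:
        print(f"{name} is sm_{major}{minor}: not in the Turing group reported here")


if __name__ == "__main__":
    main()
```

The V100 (Volta, compute capability 7.0) falls outside this check, yet one report earlier in the thread saw the error on a V100 with an older Triton, so treat the result as a hint rather than a guarantee.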

jsie7, Jun 07 '24 12:06

@jsie7 Thanks for the suggestion, I will try it out. Which torch version did you use?

ghaddarAbs, Jun 07 '24 15:06

I'm using torch v2.0.1.

jsie7, Jun 07 '24 15:06

I experienced the same problem :(

Driver Version: 550.54.14      
CUDA Version: 12.4 # by nvcc -V
Tesla T4  
triton==2.3.1

I only installed torch with CUDA 12.1 support.

SolomidHero, Jun 17 '24 14:06

The Tesla T4 is also based on the Turing microarchitecture. This just further confirms that it's an issue with that architecture.

jsie7, Jun 18 '24 06:06

I found a workaround, at least for the T4: set the model to half precision to avoid this error:

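from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

model = MambaLMHeadModel.from_pretrained('state-spaces/mamba2-130m')
model = model.half()  # fp16 weights avoid the map::at error on the T4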

Let me know if you understand why this works :)
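Applied to the Mamba2 snippet from the original report, this workaround requires casting the input as well as the module, because Mamba2 takes float activations rather than token ids. A minimal sketch, assuming Turing is identified by compute capability 7.5 as reported in this thread; `to_half_for_turing` and `demo` are hypothetical helper names, while `Mamba2`, `Module.half()` and `Tensor.half()` are real mamba_ssm/PyTorch APIs.

```python
# Compute capability 7.5 is Turing (assumption based on the reports in this thread).
TURING_CAPABILITY = (7, 5)


def to_half_for_turing(model, x, capability):
    """Cast both the module and its input to fp16 on Turing GPUs.

    Hypothetical helper: on any other compute capability the module and the
    input are returned unchanged. Casting only the model is not enough here,
    because Mamba2's input projection (an nn.Linear) would then receive fp32
    activations with fp16 weights, which PyTorch rejects as a dtype mismatch.
    """
    if tuple(capability) == TURING_CAPABILITY:
        return model.half(), x.half()
    return model, x


def demo():
    try:
        import torch
        from mamba_ssm import Mamba2
    except ImportError:
        print("torch and mamba_ssm are required for this demo; skipping.")
        return
    if not torch.cuda.is_available():
        print("No CUDA device is visible; skipping.")
        return
    batch, length, dim = 2, 64, 256
    x = torch.randn(batch, length, dim, device="cuda")
    model = Mamba2(d_model=dim, d_state=64, d_conv=4, expand=2).to("cuda")
    model, x = to_half_for_turing(model, x, torch.cuda.get_device_capability())
    y = model(x)
    print(x.shape, y.shape, y.dtype)


if __name__ == "__main__":
    demo()
```

With MambaLMHeadModel, as in the workaround above, the inputs are integer token ids, which is why calling model.half() alone was enough there.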

maksymdolgikh, Jun 20 '24 18:06

Thanks @maksymdolgikh, this worked for me, though I also have no idea why.

JulienSiems, Jun 27 '24 08:06

Same on a 2080 Ti with CUDA 11.8, torch 2.0.0, and Triton 2.0.0.

MstarLioning, Jul 24 '24 08:07

Same on a GTX 1660 Ti with CUDA 12.5, torch 2.4.0, and Triton 3.0.0. It only works if I call model.half().

Aryan-Satpathy, Aug 01 '24 04:08

One more data point: on a Quadro RTX 4000, switching from Triton 3.0.0 to Triton 2.1.0 worked, but inference is slower than with Mamba.

Jon-Zbw, Aug 21 '24 05:08

I am also running it on a 2080 Ti (on Windows) and encountered the same error; it only works after setting the model to half precision.

Drizzle-02, Jan 15 '25 12:01

I investigated a bit further and found a few issues in the Triton repository describing the same problem (e.g. https://github.com/triton-lang/triton/issues/4813, https://github.com/triton-lang/triton/issues/3011). Triton doesn't seem to correctly handle architectures such as Turing that do not support TF32 tensor core matmuls (Turing's tensor cores support fp16 but not TF32).

According to triton/README.md, the default input precision of tensor core matmuls can be controlled with the TRITON_F32_DEFAULT environment variable. I was able to solve the problem on Google Colab (T4) with:

import os
# Set this before Triton compiles the Mamba2 kernels, e.g. at the top of the script
os.environ["TRITON_F32_DEFAULT"] = "ieee"

Unlike model.half(), this keeps the full model and its activations in fp32: it only changes the default precision of Triton's fp32 matmuls from TF32, which Turing lacks, to full IEEE fp32.
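A minimal sketch that applies the TRITON_F32_DEFAULT workaround only on Turing GPUs and respects any value you already exported. The helper `configure_triton_f32` is hypothetical; it assumes, as described above, that Triton reads TRITON_F32_DEFAULT when it compiles a kernel, so it must run before the first Mamba2 forward pass (older Triton releases may not recognize the variable at all).

```python
import os

# Compute capability 7.5 is Turing (RTX 20xx, GTX 16xx, Tesla T4, ...).
TURING_CAPABILITY = (7, 5)


def configure_triton_f32(capability, env=None):
    """Default Triton fp32 dot products to full IEEE precision on Turing.

    Hypothetical helper. An explicitly set TRITON_F32_DEFAULT is kept as-is
    (setdefault), and other architectures keep Triton's own default.
    Returns the value now in effect, or None if the variable is unset.
    """
    if env is None:
        env = os.environ
    if tuple(capability) == TURING_CAPABILITY:
        env.setdefault("TRITON_F32_DEFAULT", "ieee")
    return env.get("TRITON_F32_DEFAULT")


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and torch.cuda.is_available():
        value = configure_triton_f32(torch.cuda.get_device_capability())
        print("TRITON_F32_DEFAULT =", value)
        # Build and run Mamba2 only after this point.
```

Passing a plain dict as `env` makes the helper easy to test without touching the real process environment.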

safelix, Mar 24 '25 15:03