
Long Sequence Length Inference Mamba2: CUDA error: an illegal memory access was encountered

Open · wdykas opened this issue 11 months ago · 1 comment

When running simple inference with Mamba2 on an H100 at a long sequence length (seq_len = 512000), I am hitting an illegal memory access in _mamba_chunk_scan_combined_fwd:

  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_combined.py", line 315, in _mamba_chunk_scan_combined_fwd
    dA_cumsum, dt = _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus, dt_limit=dt_limit)
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/ssd_chunk_state.py", line 684, in _chunk_cumsum_fwd
    _chunk_cumsum_fwd_kernel[grid_chunk_cs](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 156, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 133, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.10/dist-packages/triton/testing.py", line 104, in do_bench
    torch.cuda.synchronize()
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/__init__.py", line 950, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered

Here is a small reproducer:

import torch
from mamba_ssm import Mamba2
import torch.cuda as cuda

def get_gpu_memory():
    """Return GPU memory usage in MB"""
    return cuda.memory_allocated() / 1024**2, cuda.memory_reserved() / 1024**2

def print_memory_usage(step):
    allocated, reserved = get_gpu_memory()
    print(f"{step}:")
    print(f"  Allocated: {allocated:.2f} MB")
    print(f"  Reserved:  {reserved:.2f} MB")
    print("-" * 40)

# Model parameters
d_model = 4096
d_state = 128
d_conv = 4
expand = 2
seq_len = 512000
batch_size = 1

# Initialize device and print initial memory state
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Initial GPU memory state")
print_memory_usage("Before model creation")

# Create model
model = Mamba2(
    d_model=d_model,
    d_state=d_state,
    d_conv=d_conv,
    expand=expand,
).to(device)

print("After model creation")
print_memory_usage("After moving model to GPU")

# Create input tensor
x = torch.randn(batch_size, seq_len, d_model).to(device)
print_memory_usage("After creating input tensor")

# Run inference multiple times
n_repeats = 3
with torch.no_grad():
    for i in range(n_repeats):
        # Clear cache before each run
        torch.cuda.empty_cache()
        print(f"\nInference run {i+1}")
        print_memory_usage("Before inference")
        
        # Run inference
        output = model(x)
        print_memory_usage("After inference")
        
        # Verify output shape
        assert output.shape == x.shape, f"Expected shape {x.shape}, got {output.shape}"
        print(f"Output shape: {output.shape}")

# Final memory state
print("\nFinal GPU memory state")
print_memory_usage("After all runs")

# Model size information
total_params = sum(p.numel() for p in model.parameters())
print(f"\nModel Information:")
print(f"Total parameters: {total_params:,}")
print(f"Theoretical model size: {total_params * 4 / 1024**2:.2f} MB (FP32)")

The first inference run seems to work fine, but on the second there is an illegal memory access. Are any modifications to the Mamba2 kernels needed for longer sequence lengths? I am assuming this may be an indexing range issue of some kind.
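
Rough arithmetic (a back-of-the-envelope check assuming d_inner = expand * d_model = 8192 as in Mamba2, not anything taken from the kernel code) suggests that a flat element offset over a single (seq_len, d_inner) slice already exceeds the int32 range at this length:

seq_len = 512_000
d_model = 4096
expand = 2
d_inner = expand * d_model       # 8192

int32_max = 2**31 - 1            # 2,147,483,647
flat_offset = seq_len * d_inner  # 4,194,304,000 elements in one batch slice
print(flat_offset > int32_max)   # True -> a 32-bit offset would wrap around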

wdykas · Feb 05 '25 17:02

Fixed by changing the indexing to int64 in the kernels.
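
For anyone who hits the same thing: the gist is that Triton pointer offsets computed in int32 wrap around once a row index times a seqlen-scaled stride passes 2**31 - 1, so the values feeding the pointer arithmetic need to be promoted to 64-bit first. A minimal illustrative sketch of that pattern (this is not the actual mamba_ssm kernel; the kernel and argument names below are made up):

import triton
import triton.language as tl

@triton.jit
def _rowwise_copy_kernel(x_ptr, out_ptr, stride_row, n_cols, BLOCK: tl.constexpr):
    # Promote the program id to int64 before multiplying by the stride;
    # for very long sequences, row * stride_row would overflow int32 otherwise.
    row = tl.program_id(0).to(tl.int64)
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    vals = tl.load(x_ptr + row * stride_row + cols, mask=mask)
    tl.store(out_ptr + row * stride_row + cols, vals, mask=mask)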

wdykas · Feb 11 '25 17:02

Hi, just to confirm: If I apply this commit (https://github.com/wdykas/mamba/commit/bfec072693f050505b3b28f25bf532c2c9623ded), will it resolve the illegal memory access issue for long sequence lengths? I have struggled with a similar problem before. Thank you!

klae01 · May 18 '25 07:05