mamba icon indicating copy to clipboard operation
mamba copied to clipboard

RuntimeError: Triton Error [CUDA]: context is destroyed

Open yingzhige00 opened this issue 8 months ago • 8 comments

This is my program error:

Traceback (most recent call last):
  File "/root/mamba_test.py", line 34, in <module>
    mamba2_time = timeit.timeit('try_mamba2(batch, length, dim, x)', number=10, globals=globals())
  File "/opt/conda/envs/mamba2/lib/python3.10/timeit.py", line 234, in timeit
    return Timer(stmt, setup, timer, globals).timeit(number)
  File "/opt/conda/envs/mamba2/lib/python3.10/timeit.py", line 178, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "/root/mamba_test.py", line 28, in try_mamba2
    y = model(x)
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/modules/mamba2.py", line 176, in forward
    out = mamba_split_conv1d_scan_combined(
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 912, in mamba_split_conv1d_scan_combined
    return MambaSplitConv1dScanCombinedFn.apply(zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states, seq_idx, dt_limit, return_final_states, activation, rmsnorm_weight, rmsnorm_eps, outproj_weight, outproj_bias, headdim, ngroups, norm_before_gate)
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/torch/cuda/amp/autocast_mode.py", line 113, in decorate_fwd
    return fwd(*args, **kwargs)
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 777, in forward
    out_x, _, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size=chunk_size, D=D, z=None, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=True, dt_limit=dt_limit)
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_combined.py", line 323, in _mamba_chunk_scan_combined_fwd
    out, out_x = _chunk_scan_fwd(CB, x, dt, dA_cumsum, C, states, D=D, z=z, seq_idx=seq_idx)
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/mamba_ssm/ops/triton/ssd_chunk_scan.py", line 1253, in _chunk_scan_fwd
    _chunk_scan_fwd_kernel[grid](
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 83, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/triton/testing.py", line 104, in do_bench
    fn()
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 65, in _chunk_scan_fwd_kernel
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/triton/compiler/compiler.py", line 579, in __getattribute__
    self._init_handles()
  File "/opt/conda/envs/mamba2/lib/python3.10/site-packages/triton/compiler/compiler.py", line 570, in _init_handles
    mod, func, n_regs, n_spills = fn_load_binary(self.metadata["name"], self.asm[bin_path], self.shared, device)
RuntimeError: Triton Error [CUDA]: context is destroyed

This is my code:

import torch
import timeit
from mamba_ssm import Mamba, Mamba2

batch, length, dim = 2, 64, 256
cuda = "cuda:4"
x = torch.randn(batch, length, dim).to(cuda)

def try_mamba1(batch, length, dim, x):
    model = Mamba(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim, # Model dimension d_model
        d_state=16,  # SSM state expansion factor
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
    ).to(cuda)
    y = model(x)
    assert y.shape == x.shape

def try_mamba2(batch, length, dim, x):
    model = Mamba2(
        # This module uses roughly 3 * expand * d_model^2 parameters
        d_model=dim, # Model dimension d_model
        d_state=64,  # SSM state expansion factor, typically 64 or 128
        d_conv=4,    # Local convolution width
        expand=2,    # Block expansion factor
    ).to(cuda)
    y = model(x)
    assert y.shape == x.shape

mamba1_time = timeit.timeit('try_mamba1(batch, length, dim, x)', number=10, globals=globals())
print(f"Mamba 1 took {mamba1_time} seconds")

mamba2_time = timeit.timeit('try_mamba2(batch, length, dim, x)', number=10, globals=globals())
print(f"Mamba 2 took {mamba2_time} seconds")

How to solve this error

yingzhige00 avatar Jun 12 '24 13:06 yingzhige00