mamba icon indicating copy to clipboard operation
mamba copied to clipboard

ValueError: Pointer argument (at 9) cannot be accessed from Triton (cpu tensor?)

Open wdykas opened this issue 10 months ago • 1 comments

I am attempting to use CUDA graphs with cudaMallocAsync and mamba. The code seems to work fine when I use the regular allocator with CUDA graphs, but I get the errors below with the async allocator. I have seen other setups work with this configuration, but the SSM kernels seem to hit this weird error even though the pointer arguments are on the GPU. I was wondering if the team had seen errors like this before?

Traceback (most recent call last):
  File "/lustre/fs1/portfolios/llmservice/users/wdykas/mamba-inference/megatron-lm/test_graphs.py", line 148, in <module>
    main()
  File "/lustre/fs1/portfolios/llmservice/users/wdykas/mamba-inference/megatron-lm/test_graphs.py", line 118, in main
    y = selective_state_update(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/selective_state_update.py", line 181, in selective_state_update
    _selective_scan_update_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  [Previous line repeated 1 more time]
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
ValueError: Pointer argument (at 9) cannot be accessed from Triton (cpu tensor?)

#!/usr/bin/env python
import os
# Use cudaMallocAsync for performance.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch
import time
import math
from einops import rearrange, repeat
import logging

# Emit DEBUG-level records so the device-placement log lines are visible.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError as exc:
    # Chain explicitly so the underlying import failure stays attached to the
    # re-raised error as its direct cause.
    raise ImportError("selective_state_update kernel is required for this test.") from exc

def main():
    """Capture ``selective_state_update`` in a CUDA graph and replay it.

    The script builds small random inputs on the GPU, runs the kernel once
    eagerly, clones every input into a "static" buffer, runs the kernel on
    those buffers once more, and then captures a single call in a
    ``torch.cuda.CUDAGraph`` on a side stream. The graph is replayed 100 times
    and the wall-clock replay time plus the captured output are printed.

    NOTE(review): the reported failure is "Pointer argument (at 9)". The
    output tensor is not passed in by this script, so it is presumably
    allocated inside ``selective_state_update`` during capture -- confirm
    against the mamba_ssm kernel signature whether index 9 is that buffer.
    """
    device = torch.device("cuda")
    # Dummy dimensions (example values)
    batch = 2
    nheads = 4         # number of heads
    headdim = 16       # head dimension (p)
    d_state = 32       # state dimension
    ngroups = 2        # groups for B and C

    # Create dummy tensors on GPU
    ssm_state = torch.randn(batch, nheads, headdim, d_state, device=device, dtype=torch.float32)
    x = torch.randn(batch, nheads * headdim, device=device, dtype=torch.float32)
    # Reshape x to shape [batch, nheads, headdim]
    x_reshaped = rearrange(x, "b (h p) -> b h p", p=headdim)
    dt = torch.randn(batch, nheads, device=device, dtype=torch.float32)
    dt_bias = torch.randn(nheads, device=device, dtype=torch.float32)
    A = torch.randn(nheads, device=device, dtype=torch.float32)
    B = torch.randn(batch, ngroups, d_state, device=device, dtype=torch.float32)
    C = torch.randn(batch, ngroups, d_state, device=device, dtype=torch.float32)
    D = torch.randn(nheads, device=device, dtype=torch.float32)
    z = torch.randn(batch, nheads, headdim, device=device, dtype=torch.float32)

    # Mimic the repeats and rearrangements used in the main code:
    A_rep = repeat(A, "h -> h p n", p=headdim, n=d_state)
    dt_rep = repeat(dt, "b h -> b h p", p=headdim)
    dt_bias_rep = repeat(dt_bias, "h -> h p", p=headdim)
    D_rep = repeat(D, "h -> h p", p=headdim)

    # Warm-up call (outside graph capture)
    _ = selective_state_update(
        ssm_state,
        x_reshaped,
        dt_rep,
        A_rep,
        B,
        C,
        D_rep,
        z=z,
        dt_bias=dt_bias_rep,
        dt_softplus=True,   # using original boolean
    )
    torch.cuda.synchronize()

    # Clone the inputs into static buffers. They are kept in a dict (in the
    # order used for the position index in the log below) so the device check
    # can replace an entry and have the replacement used by the kernel calls.
    static = {
        "static_ssm_state": ssm_state.clone(),
        "static_x_reshaped": x_reshaped.clone(),
        "static_dt_rep": dt_rep.clone(),
        "static_A_rep": A_rep.clone(),
        "static_B": B.clone(),
        "static_C": C.clone(),
        "static_D_rep": D_rep.clone(),
        "static_z": z.clone(),
        "static_dt_bias_rep": dt_bias_rep.clone(),
    }

    # Verify all tensors are on the correct device before proceeding. Iterate
    # over a snapshot because entries may be reassigned inside the loop.
    for i, (name, tensor) in enumerate(list(static.items())):
        if tensor.device.type != 'cuda':
            logger.error(f"Tensor '{name}' at position {i} is on {tensor.device} instead of CUDA!")
            tensor = tensor.to(device)  # Try to fix it
            static[name] = tensor       # write back so the fix takes effect
            logger.info(f"Moved '{name}' to {tensor.device}")

    def log_devices(message):
        # Log the device of every static buffer using the given format string.
        logger.debug(
            message,
            static["static_ssm_state"].device,
            static["static_x_reshaped"].device,
            static["static_dt_rep"].device,
            static["static_A_rep"].device,
            static["static_B"].device,
            static["static_C"].device,
            static["static_D_rep"].device,
            static["static_dt_bias_rep"].device,
            static["static_z"].device,
        )

    def run_on_static_buffers():
        # Single kernel invocation on the static buffers; used for both the
        # eager warm-up and the captured call so the two are identical.
        return selective_state_update(
            static["static_ssm_state"],
            static["static_x_reshaped"],
            static["static_dt_rep"],
            static["static_A_rep"],
            static["static_B"],
            static["static_C"],
            static["static_D_rep"],
            z=static["static_z"],
            dt_bias=static["static_dt_bias_rep"],
            dt_softplus=True,
        )

    # Log the device of all tensors before graph capture
    log_devices("Before graph capture: ssm_state.device=%s, x_reshaped.device=%s, dt_rep.device=%s, A_rep.device=%s, B.device=%s, C.device=%s, D_rep.device=%s, dt_bias_rep.device=%s, z.device=%s")

    # Warm-up the static buffers with a dummy kernel call
    _ = run_on_static_buffers()
    torch.cuda.synchronize()

    # Create a non-default stream for graph capture.
    capture_stream = torch.cuda.Stream(device=device)
    graph = torch.cuda.CUDAGraph()

    # Make sure all operations are completed before starting graph capture
    torch.cuda.synchronize()

    # The context manager ends the capture even when the kernel launch raises,
    # so a failed capture does not leave the stream stuck in capture mode.
    with torch.cuda.graph(graph, stream=capture_stream):
        y = run_on_static_buffers()
    torch.cuda.synchronize()

    # Replay the captured graph 100 times.
    start = time.perf_counter()
    for _ in range(100):
        graph.replay()
    torch.cuda.synchronize()
    end = time.perf_counter()

    log_devices("After selective_state_update: ssm_state.device=%s, x_reshaped.device=%s, dt_rep.device=%s, A_rep.device=%s, B.device=%s, C.device=%s, D_rep.device=%s, dt_bias_rep.device=%s, z.device=%s")
    print("Selective state update graph replay time over 100 iterations: {:.4f} sec".format(end - start))
    print("Output sample:\n", y)

# Run the reproduction only when this file is executed as a script.
if __name__ == "__main__":
    main()

wdykas avatar Mar 04 '25 02:03 wdykas

I have the same problem. Did you solve it? Please let me know your solution.

2020chenlin avatar Mar 06 '25 03:03 2020chenlin