mamba
ValueError: Pointer argument (at 9) cannot be accessed from Triton (cpu tensor?)
I am attempting to use CUDA graphs with cudaMallocAsync and Mamba. The code works fine with the regular caching allocator and CUDA graphs, but with the async allocator I get the error below. Other setups I have seen work with this combination, yet the SSM kernels hit this error even though the pointer arguments are on the GPU. Has the team seen errors like this before?
Traceback (most recent call last):
  File "/lustre/fs1/portfolios/llmservice/users/wdykas/mamba-inference/megatron-lm/test_graphs.py", line 148, in <module>
    main()
  File "/lustre/fs1/portfolios/llmservice/users/wdykas/mamba-inference/megatron-lm/test_graphs.py", line 118, in main
    y = selective_state_update(
  File "/usr/local/lib/python3.10/dist-packages/mamba_ssm/ops/triton/selective_state_update.py", line 181, in selective_state_update
    _selective_scan_update_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/autotuner.py", line 338, in run
    return self.fn.run(*args, **kwargs)
  [Previous line repeated 1 more time]
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 691, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/usr/local/lib/python3.10/dist-packages/triton/backends/nvidia/driver.py", line 365, in __call__
    self.launch(*args, **kwargs)
ValueError: Pointer argument (at 9) cannot be accessed from Triton (cpu tensor?)
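
The full repro script is below; it selects the allocator backend via PYTORCH_CUDA_ALLOC_CONF before torch is imported. As a quick sanity check (a minimal sketch, not part of the repro itself), the active backend can also be queried at runtime to confirm whether the failure really tracks the cudaMallocAsync backend:

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"  # must be set before importing torch
import torch

# Expected to print "cudaMallocAsync" here; with the default caching allocator it prints "native".
print(torch.cuda.get_allocator_backend())

Repro script: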
#!/usr/bin/env python
import os

# Use cudaMallocAsync for performance.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import torch
import time
import math
from einops import rearrange, repeat

import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    raise ImportError("selective_state_update kernel is required for this test.")


def main():
    device = torch.device("cuda")

    # Dummy dimensions (example values)
    batch = 2
    nheads = 4    # number of heads
    headdim = 16  # head dimension (p)
    d_state = 32  # state dimension
    ngroups = 2   # groups for B and C

    # Create dummy tensors on GPU
    ssm_state = torch.randn(batch, nheads, headdim, d_state, device=device, dtype=torch.float32)
    x = torch.randn(batch, nheads * headdim, device=device, dtype=torch.float32)
    # Reshape x to shape [batch, nheads, headdim]
    x_reshaped = rearrange(x, "b (h p) -> b h p", p=headdim)
    dt = torch.randn(batch, nheads, device=device, dtype=torch.float32)
    dt_bias = torch.randn(nheads, device=device, dtype=torch.float32)
    A = torch.randn(nheads, device=device, dtype=torch.float32)
    B = torch.randn(batch, ngroups, d_state, device=device, dtype=torch.float32)
    C = torch.randn(batch, ngroups, d_state, device=device, dtype=torch.float32)
    D = torch.randn(nheads, device=device, dtype=torch.float32)
    z = torch.randn(batch, nheads, headdim, device=device, dtype=torch.float32)

    # Mimic the repeats and rearrangements used in the main code:
    A_rep = repeat(A, "h -> h p n", p=headdim, n=d_state)
    dt_rep = repeat(dt, "b h -> b h p", p=headdim)
    dt_bias_rep = repeat(dt_bias, "h -> h p", p=headdim)
    D_rep = repeat(D, "h -> h p", p=headdim)

    # Warm-up call (outside graph capture)
    _ = selective_state_update(
        ssm_state,
        x_reshaped,
        dt_rep,
        A_rep,
        B,
        C,
        D_rep,
        z=z,
        dt_bias=dt_bias_rep,
        dt_softplus=True,  # using original boolean
    )
    torch.cuda.synchronize()

    # Clone static buffers and explicitly ensure they are on the correct device
    static_ssm_state = ssm_state.clone().to(device)
    static_x_reshaped = x_reshaped.clone().to(device)
    static_dt_rep = dt_rep.clone().to(device)
    static_A_rep = A_rep.clone().to(device)
    static_B = B.clone().to(device)
    static_C = C.clone().to(device)
    static_D_rep = D_rep.clone().to(device)
    static_dt_bias_rep = dt_bias_rep.clone().to(device)
    static_z = z.clone().to(device)

    # Verify all tensors are on the correct device before proceeding
    tensor_names = ["static_ssm_state", "static_x_reshaped", "static_dt_rep",
                    "static_A_rep", "static_B", "static_C", "static_D_rep",
                    "static_z", "static_dt_bias_rep"]
    tensors = [static_ssm_state, static_x_reshaped, static_dt_rep,
               static_A_rep, static_B, static_C, static_D_rep,
               static_z, static_dt_bias_rep]
    for i, (name, tensor) in enumerate(zip(tensor_names, tensors)):
        if tensor.device.type != 'cuda':
            logger.error(f"Tensor '{name}' at position {i} is on {tensor.device} instead of CUDA!")
            tensor = tensor.to(device)  # Try to fix it
            logger.info(f"Moved '{name}' to {tensor.device}")

    # Log the device of all tensors before graph capture
    logger.debug("Before graph capture: ssm_state.device=%s, x_reshaped.device=%s, dt_rep.device=%s, A_rep.device=%s, B.device=%s, C.device=%s, D_rep.device=%s, dt_bias_rep.device=%s, z.device=%s",
                 static_ssm_state.device, static_x_reshaped.device, static_dt_rep.device,
                 static_A_rep.device, static_B.device, static_C.device,
                 static_D_rep.device, static_dt_bias_rep.device, static_z.device)

    # Warm-up the static buffers with a dummy kernel call
    _ = selective_state_update(
        static_ssm_state,
        static_x_reshaped,
        static_dt_rep,
        static_A_rep,
        static_B,
        static_C,
        static_D_rep,
        z=static_z,
        dt_bias=static_dt_bias_rep,
        dt_softplus=True,
    )
    torch.cuda.synchronize()

    # Create a non-default stream for graph capture.
    capture_stream = torch.cuda.Stream(device=device)
    graph = torch.cuda.CUDAGraph()

    # Make sure all operations are completed before starting graph capture
    torch.cuda.synchronize()
    with torch.cuda.stream(capture_stream):
        graph.capture_begin()
        y = selective_state_update(
            static_ssm_state,
            static_x_reshaped,
            static_dt_rep,
            static_A_rep,
            static_B,
            static_C,
            static_D_rep,
            z=static_z,
            dt_bias=static_dt_bias_rep,
            dt_softplus=True,
        )
        graph.capture_end()
    torch.cuda.synchronize()

    # Replay the captured graph 100 times.
    start = time.time()
    for i in range(100):
        graph.replay()
    torch.cuda.synchronize()
    end = time.time()

    logger.debug("After selective_state_update: ssm_state.device=%s, x_reshaped.device=%s, dt_rep.device=%s, A_rep.device=%s, B.device=%s, C.device=%s, D_rep.device=%s, dt_bias_rep.device=%s, z.device=%s",
                 static_ssm_state.device, static_x_reshaped.device, static_dt_rep.device,
                 static_A_rep.device, static_B.device, static_C.device,
                 static_D_rep.device, static_dt_bias_rep.device, static_z.device)

    print("Selective state update graph replay time over 100 iterations: {:.4f} sec".format(end - start))
    print("Output sample:\n", y)


if __name__ == "__main__":
    main()
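
For reference, here is a sketch of the same capture using the higher-level torch.cuda.graph context manager, which wraps capture_begin()/capture_end() and runs the capture on its own internal side stream. I have not verified whether it behaves any differently under cudaMallocAsync, so treat it as an untested variant that reuses the static_* buffers from the script above:

# Sketch only: alternative capture path via torch.cuda.graph, reusing the static_* buffers above.
g = torch.cuda.CUDAGraph()
side_stream = torch.cuda.Stream()
side_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side_stream):
    # Warm up on a side stream before capture, as the PyTorch CUDA graph docs recommend.
    for _ in range(3):
        _ = selective_state_update(
            static_ssm_state, static_x_reshaped, static_dt_rep, static_A_rep,
            static_B, static_C, static_D_rep,
            z=static_z, dt_bias=static_dt_bias_rep, dt_softplus=True,
        )
torch.cuda.current_stream().wait_stream(side_stream)

with torch.cuda.graph(g):
    y = selective_state_update(
        static_ssm_state, static_x_reshaped, static_dt_rep, static_A_rep,
        static_B, static_C, static_D_rep,
        z=static_z, dt_bias=static_dt_bias_rep, dt_softplus=True,
    )

for _ in range(100):
    g.replay()
torch.cuda.synchronize()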
I have the same question. Did you solve it? Please let me know your solution.