Mamba2 assertion error
Hi, when running example inference on Mamba2:
python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
An assertion error on the shape of dt is raised:
$ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2
Loading model state-spaces/mamba2-2.7b
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Number of parameters: 2702599680
Traceback (most recent call last):
File "/home/ec2-user/workspace/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 82, in <module>
out = fn()
File "/home/ec2-user/workspace/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 56, in <lambda>
fn = lambda: model.generate(
File "/home/ec2-user/workspace/mamba/mamba_ssm/utils/generation.py", line 260, in generate
output = decode(
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/ec2-user/workspace/mamba/mamba_ssm/utils/generation.py", line 221, in decode
scores.append(get_logits(sequences[-1], inference_params))
File "/home/ec2-user/workspace/mamba/mamba_ssm/utils/generation.py", line 184, in get_logits
logits = model(
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/workspace/mamba/mamba_ssm/models/mixer_seq_simple.py", line 281, in forward
hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/workspace/mamba/mamba_ssm/models/mixer_seq_simple.py", line 195, in forward
hidden_states, residual = layer(
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/workspace/mamba/mamba_ssm/modules/block.py", line 67, in forward
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
File "/home/ec2-user/workspace/mamba/mamba_ssm/modules/mamba2.py", line 226, in forward
y = mamba_chunk_scan_combined(
File "/home/ec2-user/workspace/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 563, in mamba_chunk_scan_combined
return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/ec2-user/workspace/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 529, in forward
out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
File "/home/ec2-user/workspace/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 286, in _mamba_chunk_scan_combined_fwd
assert dt.shape == (batch, seqlen, nheads)
AssertionError
Further inspection shows that at the point of the AssertionError:
dt.shape == torch.Size([1, 14, 80])
x.shape == torch.Size([1, 17, 80, 64])
B.shape == torch.Size([1, 17, 1, 128])
It looks like the seqlen of x is 3 larger than that of dt, which is what triggers the assertion error. I wonder if anyone else is getting this error and what could be causing it.
Simply put, there may have been an error during the installation of your causal_conv1d package. Without it, your code is actually falling back to nn.Conv1d, which implements the causal conv1d logic through a padding scheme, so its output needs to be truncated back to the original sequence length.
You can change lines 214 to 216 of the Mamba2 source code (mamba_ssm/modules/mamba2.py) to:
xBC = self.act(self.conv1d(xBC.transpose(1, 2))[:, :, :seqlen].transpose(1, 2))
For details, please check #437
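For intuition, here is a minimal sketch (not the library's actual code; the sizes are made up to mirror the shapes above) of why the padded nn.Conv1d fallback produces d_conv - 1 extra timesteps that have to be cut off:

import torch
import torch.nn as nn

batch, seqlen, dim, d_conv = 1, 14, 8, 4
xBC = torch.randn(batch, seqlen, dim)

# A depthwise conv with padding = d_conv - 1 emulates a causal conv,
# but its output is seqlen + d_conv - 1 steps long.
conv1d = nn.Conv1d(dim, dim, kernel_size=d_conv, groups=dim, padding=d_conv - 1)

out = conv1d(xBC.transpose(1, 2))
print(out.shape[-1])        # 17: the 3 extra steps seen in the traceback above

out = out[:, :, :seqlen]    # truncate back to the prompt length
print(out.shape[-1])        # 14: now matches dt, so the assertion passes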
Update: after I reinstalled the causal-conv1d library, the code at the latest commit worked. Thanks!
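If it helps others, a quick sanity check (assuming the causal_conv1d package name and the causal_conv1d_fn entry point that mamba_ssm tries to import) is:

# If this import fails, mamba_ssm silently falls back to the padded
# nn.Conv1d path described above.
try:
    from causal_conv1d import causal_conv1d_fn
    print("causal-conv1d fused kernel is available")
except ImportError:
    print("causal-conv1d not found; the nn.Conv1d fallback will be used")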
Hi @AlwaysFHao
I am using the official GitHub version, based on commit 03a38fb.
I tried what you described, but I ran into an error.
Here is my code:
import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"  # set earlier in my script
model = MambaLMHeadModel.from_pretrained(pretrained_model_name="state-spaces/mamba2-130m")
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
text = "The text of the declaration of independence is:"
inputs = {'input_ids': tokenizer(text, return_tensors="pt")['input_ids'].to(device)}
input_ids = inputs['input_ids']
model.to(device)
out = model.generate(input_ids, max_length=100, temperature=0)
Here is the last part of the error log:
File ~/mamba_official/mamba_ssm/models/mixer_seq_simple.py:194, in MixerModel.forward(self, input_ids, inference_params, **mixer_kwargs)
    192 residual = None
    193 for layer in self.layers:
--> 194 hidden_states, residual = layer(
    195     hidden_states, residual, inference_params=inference_params, **mixer_kwargs
    196 )
    197 if not self.fused_add_norm:
    198     residual = (hidden_states + residual) if residual is not None else hidden_states
...
File ~/miniconda3/envs/ssm/lib/python3.9/site-packages/triton/runtime/autotuner.py:81
--> 81 self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File <string>:65, in _chunk_scan_fwd_kernel(cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr, chunk_size, hdim, dstate, batch, seqlen, nheads_ngroups_ratio, stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, stride_D_head, IS_CAUSAL, HAS_D, D_HAS_HDIM, HAS_Z, HAS_SEQ_IDX, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, BLOCK_SIZE_DSTATE, IS_TRITON_22, grid, num_warps, num_stages, extern_libs, stream, warmup, device, device_type)
ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
My problem was solved by installing causal-conv1d==1.4 and using causal_conv1d_fn. I don't know why Mamba's fallback conv1d implementation fails, though.
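For reference, a minimal sketch of calling the fused kernel directly (hedged: it assumes the causal_conv1d_fn(x, weight, bias, activation=...) interface with x shaped (batch, dim, seqlen), and a CUDA device, since the kernel is GPU-only; the shapes here are illustrative):

import torch
from causal_conv1d import causal_conv1d_fn  # installed via e.g. pip install causal-conv1d==1.4.0

batch, dim, seqlen, width = 1, 8, 14, 4
x = torch.randn(batch, dim, seqlen, device="cuda")
weight = torch.randn(dim, width, device="cuda")
bias = torch.randn(dim, device="cuda")

# The fused kernel is causal by construction: the output keeps the input seqlen,
# so no truncation is needed, unlike the padded nn.Conv1d fallback.
y = causal_conv1d_fn(x, weight, bias, activation="silu")
print(y.shape)  # torch.Size([1, 8, 14])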