
Mamba2 assertion error

Open wyc1997 opened this issue 1 year ago • 3 comments

Hi, when running the example inference benchmark on Mamba2:

python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2

An assertion error on the shape of dt is raised:

$ python benchmarks/benchmark_generation_mamba_simple.py --model-name "state-spaces/mamba2-2.7b" --prompt "My cat wrote all this CUDA code for a new language model and" --topp 0.9 --temperature 0.7 --repetition-penalty 1.2

Loading model state-spaces/mamba2-2.7b
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Number of parameters: 2702599680
Traceback (most recent call last):
  File "/home/ec2-user/workspace/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 82, in <module>
    out = fn()
  File "/home/ec2-user/workspace/mamba/benchmarks/benchmark_generation_mamba_simple.py", line 56, in <lambda>
    fn = lambda: model.generate(
  File "/home/ec2-user/workspace/mamba/mamba_ssm/utils/generation.py", line 260, in generate
    output = decode(
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/utils/generation.py", line 221, in decode
    scores.append(get_logits(sequences[-1], inference_params))
  File "/home/ec2-user/workspace/mamba/mamba_ssm/utils/generation.py", line 184, in get_logits
    logits = model(
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/models/mixer_seq_simple.py", line 281, in forward
    hidden_states = self.backbone(input_ids, inference_params=inference_params, **mixer_kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/models/mixer_seq_simple.py", line 195, in forward
    hidden_states, residual = layer(
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/modules/block.py", line 67, in forward
    hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/modules/mamba2.py", line 226, in forward
    y = mamba_chunk_scan_combined(
  File "/home/ec2-user/workspace/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 563, in mamba_chunk_scan_combined
    return MambaChunkScanCombinedFn.apply(x, dt, A, B, C, chunk_size, D, z, dt_bias, initial_states, seq_idx, dt_softplus, dt_limit, return_final_states)
  File "/opt/conda/envs/mamba/lib/python3.10/site-packages/torch/autograd/function.py", line 598, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/ec2-user/workspace/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 529, in forward
    out, out_x, dt_out, dA_cumsum, states, final_states = _mamba_chunk_scan_combined_fwd(x, dt, A, B, C, chunk_size, D=D, z=z, dt_bias=dt_bias, initial_states=initial_states, seq_idx=seq_idx, dt_softplus=dt_softplus, dt_limit=dt_limit)
  File "/home/ec2-user/workspace/mamba/mamba_ssm/ops/triton/ssd_combined.py", line 286, in _mamba_chunk_scan_combined_fwd
    assert dt.shape == (batch, seqlen, nheads)
AssertionError

Further inspection shows the following shapes at the point of the AssertionError:

dt.shape == torch.Size([1, 14, 80])
x.shape == torch.Size([1, 17, 80, 64])
B.shape == torch.Size([1, 17, 1, 128])

It seems the seqlen of x is 3 larger than that of dt, which triggers the assertion. I wonder if anyone else is getting this error and what could be causing it here.
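For what it's worth, a depthwise nn.Conv1d built with padding=d_conv-1 produces exactly 3 extra timesteps when d_conv is 4 and its output is never truncated. A minimal standalone sketch (channel count chosen to match the shapes above; d_conv=4 assumed from the Mamba2 default) reproduces the 14-to-17 growth:

import torch
import torch.nn as nn

# Emulating a causal conv via padding; with stride 1: L_out = L_in + 2*padding - kernel + 1
seqlen, d_conv, channels = 14, 4, 80
conv = nn.Conv1d(channels, channels, kernel_size=d_conv, groups=channels, padding=d_conv - 1)
x = torch.randn(1, channels, seqlen)
print(conv(x).shape)                # torch.Size([1, 80, 17])  -> 3 extra steps
print(conv(x)[..., :seqlen].shape)  # torch.Size([1, 80, 14])  -> truncated back to seqlen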

wyc1997 · Jun 21 '24 23:06

Simply put, there was probably an error during the installation of your causal_conv1d package. Your code is currently falling back to nn.Conv1d, which implements the causal-conv1d logic via a padding scheme, so its output has to be truncated back to the original sequence length. You can change lines 214 to 216 in the Mamba2 source code to

xBC = self.act(self.conv1d(xBC.transpose(1, 2))[:, :, :seqlen].transpose(1, 2))

For details, please check #437.
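For context, the fallback branch looks roughly like this (a sketch paraphrased from that part of mamba2.py, so nearby lines may differ by commit); only the [:, :, :seqlen] slice is new:

# mamba_ssm/modules/mamba2.py, around lines 214-216 (sketch, not verbatim).
# self.conv1d is constructed with padding=d_conv - 1, so its output has
# seqlen + d_conv - 1 steps and must be cut back to seqlen.

# before: the padded tail is kept, which later trips the dt-shape assertion
xBC = self.act(self.conv1d(xBC.transpose(1, 2)).transpose(1, 2))

# after: truncate to the original sequence length before the chunked scan
xBC = self.act(self.conv1d(xBC.transpose(1, 2))[:, :, :seqlen].transpose(1, 2))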

AlwaysFHao · Jul 02 '24 08:07

Update: after I reinstalled the causal-conv1d library, the latest commit of the code worked. Thanks!

Hi @AlwaysFHao

I am using the official GitHub version, based on commit 03a38fb.

I applied the change you described, but I ran into an error.

Here is my code:

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

device = "cuda"  # assumed; the original snippet does not show how device was set
model = MambaLMHeadModel.from_pretrained(pretrained_model_name="state-spaces/mamba2-130m")
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-2.8b-hf")
text = "The text of the declaration of independence is:"
inputs = {'input_ids': tokenizer(text, return_tensors="pt")['input_ids'].to(device)}
input_ids = inputs['input_ids']
model.to(device)
out = model.generate(input_ids, max_length=100, temperature=0)

Here is the last part of the error log:

File ~/mamba_official/mamba_ssm/models/mixer_seq_simple.py:194, in MixerModel.forward(self, input_ids, inference_params, **mixer_kwargs)
    192 residual = None
    193 for layer in self.layers:
--> 194     hidden_states, residual = layer(
    195         hidden_states, residual, inference_params=inference_params, **mixer_kwargs
    196     )
    197 if not self.fused_add_norm:
    198     residual = (hidden_states + residual) if residual is not None else hidden_states
...

File ~/miniconda3/envs/ssm/lib/python3.9/site-packages/triton/runtime/autotuner.py:81
--> 81 self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)

File <string>:65, in _chunk_scan_fwd_kernel(cb_ptr, x_ptr, z_ptr, out_ptr, out_x_ptr, dt_ptr, dA_cumsum_ptr, seq_idx_ptr, C_ptr, prev_states_ptr, D_ptr, chunk_size, hdim, dstate, batch, seqlen, nheads_ngroups_ratio, stride_cb_batch, stride_cb_chunk, stride_cb_head, stride_cb_csize_m, stride_cb_csize_k, stride_x_batch, stride_x_seqlen, stride_x_head, stride_x_hdim, stride_z_batch, stride_z_seqlen, stride_z_head, stride_z_hdim, stride_out_batch, stride_out_seqlen, stride_out_head, stride_out_hdim, stride_dt_batch, stride_dt_chunk, stride_dt_head, stride_dt_csize, stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize, stride_seq_idx_batch, stride_seq_idx_seqlen, stride_C_batch, stride_C_seqlen, stride_C_head, stride_C_dstate, stride_states_batch, stride_states_chunk, stride_states_head, stride_states_hdim, stride_states_dstate, stride_D_head, IS_CAUSAL, HAS_D, D_HAS_HDIM, HAS_Z, HAS_SEQ_IDX, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K, BLOCK_SIZE_DSTATE, IS_TRITON_22, grid, num_warps, num_stages, extern_libs, stream, warmup, device, device_type)

ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

zixianwang2022 · Jul 07 '24 22:07

My problem was solved by installing causal-conv1d==1.4 and using causal_conv1d_fn. I don't know why Mamba's fallback conv1d implementation fails, though.
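In case it helps others hitting this, a quick way to check which path will be taken (a minimal sketch mirroring the optional import that mamba2.py performs, as far as I can tell; the printed messages are just illustrative):

try:
    # CUDA kernels from the causal-conv1d package
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
    print("causal-conv1d kernels importable; Mamba2 can use causal_conv1d_fn")
except ImportError:
    print("causal-conv1d not importable; Mamba2 falls back to padded nn.Conv1d")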

wyc1997 · Jul 18 '24 21:07