new version of transformers, no need to use BetterTransformer, try setting attn impl to sdpa...
attn imp: <class 'transformers.models.llama.modeling_llama.LlamaSdpaAttention'>
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
running layers(cuda:0): 1%|▊ | 1/129 [00:03<07:01, 3.30s/it]
Traceback (most recent call last):
File "/root/test.py", line 18, in
generation_output = model.generate(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1989, in generate
result = self._sample(
File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2932, in _sample
outputs = self(**model_inputs, return_dict=True)
File "/usr/local/lib/python3.10/dist-packages/airllm/airllm_base.py", line 369, in call
return self.forward(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/airllm/airllm_base.py", line 569, in forward
new_seq = layer(seq, **kwargs)[0]
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 677, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 565, in forward
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[1, 9, 8, 128]' is invalid for input of size 18432
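For context, here is a minimal sketch of the kind of script that produces this trace. The model id, prompt, and generation arguments are placeholders, and the call pattern follows the AirLLM README-style `AutoModel` usage; this is only an assumption about what `/root/test.py` contains (the 4-bit compression flag is likewise guessed from the BitsAndBytesConfig warning above):

```python
# Hypothetical reconstruction of /root/test.py -- model id and prompt are placeholders.
from airllm import AutoModel

MAX_LENGTH = 128

# compression='4bit' is assumed here because of the BitsAndBytesConfig warning in the log.
model = AutoModel.from_pretrained(
    "meta-llama/Meta-Llama-3.1-70B-Instruct",
    compression='4bit',
)

input_text = ["What is the capital of the United States?"]
input_tokens = model.tokenizer(
    input_text,
    return_tensors="pt",
    truncation=True,
    max_length=MAX_LENGTH,
)

# Corresponds to the model.generate(...) call at test.py line 18 in the traceback.
generation_output = model.generate(
    input_tokens["input_ids"].cuda(),
    max_new_tokens=20,
    use_cache=True,
    return_dict_in_generate=True,
)

print(model.tokenizer.decode(generation_output.sequences[0]))
```

For what it's worth, the failing view expects 1 × 9 × 8 × 128 = 9216 elements, while key_states actually holds 18432 (= 1 × 9 × 16 × 128), exactly twice as many, which points at a mismatch between the per-layer weights AirLLM loads and the config's num_key_value_heads / head_dim.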