airllm

unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit

Open · kendiyang opened this issue 6 months ago · 2 comments

new version of transfomer, no need to use BetterTransformer, try setting attn impl to sdpa...
attn imp: <class 'transformers.models.llama.modeling_llama.LlamaSdpaAttention'>
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
running layers(cuda:0): 1%|▊ | 1/129 [00:03<07:01, 3.30s/it]
Traceback (most recent call last):
  File "/root/test.py", line 18, in <module>
    generation_output = model.generate(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 1989, in generate
    result = self._sample(
  File "/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py", line 2932, in _sample
    outputs = self(**model_inputs, return_dict=True)
  File "/usr/local/lib/python3.10/dist-packages/airllm/airllm_base.py", line 369, in __call__
    return self.forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/airllm/airllm_base.py", line 569, in forward
    new_seq = layer(seq, **kwargs)[0]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 677, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/llama/modeling_llama.py", line 565, in forward
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
RuntimeError: shape '[1, 9, 8, 128]' is invalid for input of size 18432
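The numbers in the final `RuntimeError` can be checked by hand. The sketch below only redoes that arithmetic with the values printed in the traceback (batch size 1, 9 prompt tokens, 8 key/value heads, head dimension 128); the variable names are illustrative and it does not load the model. It shows that `key_states` held exactly twice as many elements as the `.view()` call expects, i.e. 2048 values per token where 8 × 128 = 1024 were expected. The log alone does not say why the key projection produced that width, so any explanation (for example, a mismatch between the pre-quantized bnb-4bit checkpoint and how the layer weights were loaded) is an unverified assumption.

```python
from math import prod

# Values copied from the RuntimeError message in the traceback.
bsz, q_len = 1, 9                      # batch size and prompt length in tokens
num_key_value_heads, head_dim = 8, 128
actual_elements = 18432                # size of key_states before .view()

# Number of elements the view(bsz, q_len, heads, head_dim) call requires.
expected_elements = prod((bsz, q_len, num_key_value_heads, head_dim))

# Per-token width of the key projection: expected vs. what was produced.
expected_width = num_key_value_heads * head_dim
actual_width = actual_elements // (bsz * q_len)

print(expected_elements, expected_width, actual_width)  # 9216 1024 2048
print(actual_elements / expected_elements)              # 2.0
```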

kendiyang · Sep 01 '24 07:09