bitsandbytes
RuntimeError: mat1 and mat2 shapes cannot be multiplied (when using peft)
System Info
Ubuntu 20.04
Python 3.10.14
torch 2.3.0
transformers 4.42.3
bitsandbytes 0.42.0
CUDA Version: 12.4
GPU 3090
torch.cuda.is_available(): True
Reproduction
import torch
import bitsandbytes as bnb
import peft
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

base_path = '/data/Qwen1.5-0.5B-Chat/'
tok = AutoTokenizer.from_pretrained(base_path, trust_remote_code=True)

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)
llm = AutoModelForCausalLM.from_pretrained(
    base_path, trust_remote_code=True, quantization_config=nf4_config)

# Apply LoRA to every 4-bit linear layer.
tgt_mods = {
    name.split('.')[-1]
    for name, mod in llm.named_modules()
    if isinstance(mod, bnb.nn.Linear4bit)
}
lora_config = peft.LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=tgt_mods,
    lora_dropout=0.1,
    task_type=peft.TaskType.CAUSAL_LM,
    inference_mode=False,
    bias="none",
)
llm = peft.get_peft_model(llm, lora_config)
llm.cuda()  # !!! the error only appears after this explicit move to GPU

# `chat` is a thin wrapper around model.generate (see qwen_chat.py in the traceback below).
llm.chat(tok, '你好', max_length=4096)
'''
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[63], line 1
----> 1 llm.chat(tok, '你好', max_length=4096)
File ~/gpt_test/src/qwen_chat.py:8, in chat(model, tok, ques, history, **kw)
3 def chat(model, tok, ques, history=[], **kw):
4 iids = tok.apply_chat_template(
5 history + [{'role': 'user', 'content': ques}],
6 add_generation_prompt=1,
7 )
----> 8 oids = model.generate(
9 inputs=torch.tensor([iids]).to(model.device),
10 **(model.generation_config.to_dict() | kw),
11 )
12 oids = oids[0][len(iids):].tolist()
13 if oids[-1] == tok.eos_token_id:
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/peft/peft_model.py:1190, in PeftModelForCausalLM.generate(self, *args, **kwargs)
1188 with self._enable_peft_forward_hooks(*args, **kwargs):
1189 kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1190 outputs = self.base_model.generate(*args, **kwargs)
1191 else:
1192 outputs = self.base_model.generate(**kwargs)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/generation/utils.py:1914, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
1906 input_ids, model_kwargs = self._expand_inputs_for_generation(
1907 input_ids=input_ids,
1908 expand_size=generation_config.num_return_sequences,
1909 is_encoder_decoder=self.config.is_encoder_decoder,
1910 **model_kwargs,
1911 )
1913 # 13. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 1914 result = self._sample(
1915 input_ids,
1916 logits_processor=prepared_logits_processor,
1917 logits_warper=prepared_logits_warper,
1918 stopping_criteria=prepared_stopping_criteria,
1919 generation_config=generation_config,
1920 synced_gpus=synced_gpus,
1921 streamer=streamer,
1922 **model_kwargs,
1923 )
1925 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
1926 # 11. prepare logits warper
1927 prepared_logits_warper = (
1928 self._get_logits_warper(generation_config, device=input_ids.device)
1929 if generation_config.do_sample
1930 else None
1931 )
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/generation/utils.py:2651, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
2648 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2650 # forward pass to get next token
-> 2651 outputs = self(
2652 **model_inputs,
2653 return_dict=True,
2654 output_attentions=output_attentions,
2655 output_hidden_states=output_hidden_states,
2656 )
2658 if synced_gpus and this_peer_finished:
2659 continue # don't waste resources running the code we don't need
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1221, in Qwen2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1218 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1220 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1221 outputs = self.model(
1222 input_ids=input_ids,
1223 attention_mask=attention_mask,
1224 position_ids=position_ids,
1225 past_key_values=past_key_values,
1226 inputs_embeds=inputs_embeds,
1227 use_cache=use_cache,
1228 output_attentions=output_attentions,
1229 output_hidden_states=output_hidden_states,
1230 return_dict=return_dict,
1231 cache_position=cache_position,
1232 )
1234 hidden_states = outputs[0]
1235 logits = self.lm_head(hidden_states)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:1023, in Qwen2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1012 layer_outputs = self._gradient_checkpointing_func(
1013 decoder_layer.__call__,
1014 hidden_states,
(...)
1020 cache_position,
1021 )
1022 else:
-> 1023 layer_outputs = decoder_layer(
1024 hidden_states,
1025 attention_mask=causal_mask,
1026 position_ids=position_ids,
1027 past_key_value=past_key_values,
1028 output_attentions=output_attentions,
1029 use_cache=use_cache,
1030 cache_position=cache_position,
1031 )
1033 hidden_states = layer_outputs[0]
1035 if use_cache:
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:763, in Qwen2DecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, **kwargs)
760 hidden_states = self.input_layernorm(hidden_states)
762 # Self Attention
--> 763 hidden_states, self_attn_weights, present_key_value = self.self_attn(
764 hidden_states=hidden_states,
765 attention_mask=attention_mask,
766 position_ids=position_ids,
767 past_key_value=past_key_value,
768 output_attentions=output_attentions,
769 use_cache=use_cache,
770 cache_position=cache_position,
771 )
772 hidden_states = residual + hidden_states
774 # Fully Connected
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:650, in Qwen2SdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
639 return super().forward(
640 hidden_states=hidden_states,
641 attention_mask=attention_mask,
(...)
645 use_cache=use_cache,
646 )
648 bsz, q_len, _ = hidden_states.size()
--> 650 query_states = self.q_proj(hidden_states)
651 key_states = self.k_proj(hidden_states)
652 value_states = self.v_proj(hidden_states)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/peft/tuners/lora/bnb.py:452, in Linear4bit.forward(self, x, *args, **kwargs)
450 result = self.base_layer(x, *args, **kwargs)
451 else:
--> 452 result = self.base_layer(x, *args, **kwargs)
453 # As per Tim Dettmers, for 4bit, we need to defensively clone here.
454 # The reason is that in some cases, an error can occur that backprop
455 # does not work on a manipulated view. This issue may be solved with
456 # newer PyTorch versions but this would need extensive testing to be
457 # sure.
458 result = result.clone()
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/accelerate/hooks.py:166, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
164 output = module._old_forward(*args, **kwargs)
165 else:
--> 166 output = module._old_forward(*args, **kwargs)
167 return module._hf_hook.post_forward(module, output)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/bitsandbytes/nn/modules.py:256, in Linear4bit.forward(self, x)
253 x = x.to(self.compute_dtype)
255 bias = None if self.bias is None else self.bias.to(self.compute_dtype)
--> 256 out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
258 out = out.to(inp_dtype)
260 return out
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:577, in matmul_4bit(A, B, quant_state, out, bias)
575 return out
576 else:
--> 577 return MatMul4Bit.apply(A, B, out, bias, quant_state)
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/torch/autograd/function.py:598, in Function.apply(cls, *args, **kwargs)
595 if not torch._C._are_functorch_transforms_active():
596 # See NOTE: [functorch vjp and autograd interaction]
597 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 598 return super().apply(*args, **kwargs) # type: ignore[misc]
600 if not is_setup_ctx_defined:
601 raise RuntimeError(
602 "In order to use an autograd.Function with functorch transforms "
603 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
604 "staticmethod. For more details, please see "
605 "https://pytorch.org/docs/master/notes/extending.func.html"
606 )
File ~/anaconda3/envs/bnb/lib/python3.10/site-packages/bitsandbytes/autograd/_functions.py:516, in MatMul4Bit.forward(ctx, A, B, out, bias, quant_state)
511 return torch.empty(A.shape[:-1] + B_shape[:1], dtype=A.dtype, device=A.device)
514 # 1. Dequantize
515 # 2. MatmulnN
--> 516 output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
518 # 3. Save state
519 ctx.state = quant_state
RuntimeError: mat1 and mat2 shapes cannot be multiplied (19x1024 and 1x524288)
'''
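The operand shapes in the error are consistent with the packed 4-bit weight being quantized a second time: Qwen1.5-0.5B has hidden size 1024, so q_proj stores a 1024x1024 weight, which packs to a 524288x1 uint8 tensor; if the later `.cuda()` call re-quantizes that already-packed tensor, `quant_state.shape` ends up as (524288, 1) and `dequantize_4bit(...).t()` yields the 1x524288 operand seen above. The snippet below is a diagnostic sketch only (not a fix), meant to be run on a freshly loaded model just before and after the `.cuda()` line of the reproduction; the module path and the `weight.quant_state.shape` attribute are assumptions based on the traceback and on bitsandbytes 0.42.0.

# Diagnostic sketch: run after peft.get_peft_model(...) but before llm.cuda().
# The module path is an assumption inferred from the traceback (peft wrapper -> base_layer).
q_proj = llm.base_model.model.model.layers[0].self_attn.q_proj.base_layer

print(q_proj.weight.shape)              # packed 4-bit storage, e.g. torch.Size([524288, 1])
print(q_proj.weight.quant_state.shape)  # original weight shape, e.g. torch.Size([1024, 1024])

llm.cuda()  # the line marked !!! in the reproduction

print(q_proj.weight.shape)
print(q_proj.weight.quant_state.shape)  # if this now reads torch.Size([524288, 1]), the
                                        # already-packed weight was quantized again, which
                                        # matches the 1x524288 operand in the RuntimeError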
Expected behavior
Generation should complete and return a chat response instead of raising the shape-mismatch RuntimeError.
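A hedged workaround sketch, under the assumption that the post-hoc `llm.cuda()` call is what corrupts the quantized weights: place the model on the GPU at load time via `device_map`, so the 4-bit weights are never moved after quantization. This is not a confirmed fix for the underlying bitsandbytes behavior.

# Workaround sketch (assumption: avoid moving an already-quantized model).
llm = AutoModelForCausalLM.from_pretrained(
    base_path,
    trust_remote_code=True,
    quantization_config=nf4_config,
    device_map={'': 0},  # load every module directly onto cuda:0
)
llm = peft.get_peft_model(llm, lora_config)
# no llm.cuda() afterwards
llm.chat(tok, '你好', max_length=4096)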