merge_adapter does not work with a bitsandbytes 4-bit quantized model.
System Info
peft==0.15.1
Who can help?
No response
Information
- [ ] The official example scripts
- [x] My own modified scripts
Tasks
- [ ] An officially supported task in the examples folder
- [ ] My own task or dataset (give details below)
Reproduction
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

model_name = "meta-llama/Llama-3.1-8B-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=8,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
)

model = get_peft_model(model, peft_config)
model.merge_adapter()
File "/trl/trl/trainer/grpo_trainer.py", line 662, in _move_model_to_vllm self.model.merge_adapter() File "/peft/tuners/tuners_utils.py", line 595, in merge_adapter module.merge(adapter_names=adapter_names) File "/peft/tuners/lora/bnb.py", line 387, in merge self.get_base_layer().weight = bnb.nn.Params4bit(w_data.to("cpu"), **kwargs).to(weight.device) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ TypeError: Params4bit.new() got an unexpected keyword argument 'ds_param_type'
I am trying to merge a LoRA adapter back into a 4-bit quantized model, with the option to unmerge it again at a later point.
It works if I use temp_model = self.model.merge_and_unload() instead, but that removes the adapter permanently.
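For context, the arithmetic that merge_adapter/unmerge_adapter is expected to perform can be sketched in plain Python. This is only an illustration of the LoRA merge formula W' = W + scale * (B @ A), not the PEFT implementation; the scale value lora_alpha / r matches the config above.

```python
# Illustration of LoRA merge/unmerge arithmetic (not the PEFT implementation).

def matmul(a, b):
    """Multiply two matrices given as lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def merge(W, A, B, scale):
    """Fold the adapter into the base weight: W + scale * (B @ A)."""
    delta = matmul(B, A)
    return [[W[i][j] + scale * delta[i][j] for j in range(len(W[0]))]
            for i in range(len(W))]

def unmerge(W_merged, A, B, scale):
    """Subtract scale * (B @ A) to recover the original base weight."""
    delta = matmul(B, A)
    return [[W_merged[i][j] - scale * delta[i][j] for j in range(len(W_merged[0]))]
            for i in range(len(W_merged))]

W = [[1.0, 0.0], [0.0, 1.0]]   # base weight (2x2)
A = [[1.0, 2.0]]               # LoRA A matrix (r=1, in_features=2)
B = [[0.5], [0.25]]            # LoRA B matrix (out_features=2, r=1)
scale = 16 / 8                 # lora_alpha / r, as in the config above

W_merged = merge(W, A, B, scale)
W_restored = unmerge(W_merged, A, B, scale)
assert W_restored == W         # round trip recovers the base weight
```

With a 4-bit quantized base layer the round trip is lossy in practice (the merged weight is re-quantized), which is why merging and unmerging a quantized model is more delicate than this float sketch suggests.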
Expected behavior
Adapter gets merged into the 4-bit model.
Thanks for the report. Merging works for me, so it is most likely a package version issue. Which versions of torch, bitsandbytes, and transformers are you using? If they are not the latest, could you please upgrade them and run your script again?
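To collect those versions in one go, a small stdlib-only helper can be used (a sketch; the strings are the PyPI distribution names, and the function name is just for illustration):

```python
from importlib.metadata import version, PackageNotFoundError

def report_versions(packages):
    """Return {package: installed version string or 'not installed'}."""
    report = {}
    for pkg in packages:
        try:
            report[pkg] = version(pkg)
        except PackageNotFoundError:
            report[pkg] = "not installed"
    return report

for name, ver in report_versions(
    ["torch", "bitsandbytes", "transformers", "peft"]
).items():
    print(f"{name}: {ver}")
```

Pasting the printed output into the issue makes it easy to check whether the affected bitsandbytes/transformers combination is outdated.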
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.