peft
Autocast adapter weights if fp16/bf16
As discussed internally, we want to automatically cast the adapter weights to float32 when the base model uses float16. float16 adapter weights are not conducive to stable training and raise errors when used with AMP.
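One concrete reason low-precision weights are fragile can be shown with plain PyTorch. This is an illustration only, not the internal discussion referred to above, and `apply_small_update` is a hypothetical helper. float16 has 10 mantissa bits, so representable values near 1.0 are about 1e-3 apart, and a small update such as 1e-4 gets rounded away entirely. bfloat16 is coarser still.

```python
import torch


def apply_small_update(dtype: torch.dtype, step: float = 1e-4) -> torch.Tensor:
    # Hypothetical helper: a weight of 1.0 stored in `dtype`, nudged by a tiny
    # update, similar to an optimizer step with a small learning rate.
    weight = torch.tensor(1.0, dtype=dtype)
    return weight + torch.tensor(step, dtype=dtype)


for dtype in (torch.float16, torch.bfloat16, torch.float32):
    updated = apply_small_update(dtype)
    status = "update lost" if updated.item() == 1.0 else "update kept"
    print(dtype, updated.item(), status)
```

Keeping the trainable adapter weights in float32 avoids this rounding for the parameters that are actually updated.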
Previously, we had to recommend that users manually cast the adapter weights if they loaded the base model in float16, because PEFT would choose the same dtype for the adapter as for the base weights. Forgetting this is a common source of errors, so we chose to automate it.
If this causes trouble, users can opt out of this behavior by passing autocast_adapter_dtype=False to get_peft_model, PeftModel.from_pretrained, or PeftModel.load_adapter.
I decided to implement the method that does the actual casting at the adapter model level (i.e. LoraModel et al.) instead of at the PeftModel level. The idea was to give each tuner method more flexibility in how it handles dtype casting.
Right now, I have implemented the casting only for LoraModel. If we think this should be applied more generally, we can also lift this to the level of BaseTuner, in which case LoHa, OFT, etc. would all inherit the same behavior. LMK what you think.
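To make the idea concrete, here is a hypothetical helper in plain PyTorch. It is not the code in this PR. It upcasts float16/bfloat16 adapter parameters to float32 and leaves the base weights alone. It assumes that adapter parameters can be recognized by name (here, a `lora_` prefix on a toy `ToyLoraLinear` layer). The forward pass then has to reconcile the mixed dtypes, which this sketch does not cover.

```python
import torch
from torch import nn


def cast_adapter_params_to_fp32(model: nn.Module, adapter_prefix: str = "lora_") -> int:
    """Hypothetical helper: upcast fp16/bf16 adapter parameters to float32.

    Adapter parameters are identified purely by name (assumption: their names
    contain `adapter_prefix`). Base weights are not touched.
    Returns the number of parameters that were cast.
    """
    low_precision = {torch.float16, torch.bfloat16}
    n_cast = 0
    for name, param in model.named_parameters():
        if adapter_prefix in name and param.dtype in low_precision:
            # Swap the underlying data so the nn.Parameter object itself is kept.
            param.data = param.data.to(torch.float32)
            n_cast += 1
    return n_cast


class ToyLoraLinear(nn.Module):
    # Hypothetical stand-in for a LoRA-wrapped linear layer.
    def __init__(self, dim: int = 8, r: int = 2):
        super().__init__()
        self.base_layer = nn.Linear(dim, dim)
        self.lora_A = nn.Linear(dim, r, bias=False)
        self.lora_B = nn.Linear(r, dim, bias=False)


layer = ToyLoraLinear().to(torch.bfloat16)
n = cast_adapter_params_to_fp32(layer)
print(n)  # 2 (lora_A.weight and lora_B.weight)
print(layer.base_layer.weight.dtype, layer.lora_A.weight.dtype)
```

Whether such a method lives on LoraModel or on a shared base class only determines which tuners pick it up. The casting itself is the same.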
This PR should be reviewed carefully, as it could break existing code if something important was missed. We also need to add a note about this change in behavior to the release notes of the upcoming release.
@pacman100 Thanks for the review, I have addressed your comments:
- Fixed the typo
- Also cast bfloat16
- Also cast for other tuners, not just LoRA (e.g. LoHa, OFT)
I tend to think we should do a release first and only then merge this PR. That way, it has a bit of time to be tested by folks using the main branch.
Agreed.
https://github.com/huggingface/peft/pull/1706#issuecomment-2101163837
Agreed with that too
Status: This is ready to be merged after the PEFT v0.11 release
Hi, I trained the qwen2-1.5b model (in bfloat16) with LoRA, and I found that the dtype of the trained (merged) model became float32. Is this normal? How do I cast it back to bfloat16?
Could you please show the code to reproduce this?
Thank you for helping. I use the open-instruct code base. The relevant parts of the training code are:
```python
import os

import torch
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM

# construct the peft model
model = AutoModelForCausalLM.from_pretrained(
    args.model_name_or_path,
    from_tf=bool(".ckpt" in args.model_name_or_path),
    config=config,
    trust_remote_code=args.trust_remote_code,
    low_cpu_mem_usage=args.low_cpu_mem_usage,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2" if args.use_flash_attn else "eager",
    revision=args.model_revision,
    token=os.getenv("HF_TOKEN", None),
)
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=args.lora_rank,
    lora_alpha=args.lora_alpha,
    lora_dropout=args.lora_dropout,
    target_modules=["q_proj", "o_proj", "v_proj", "k_proj", "gate_proj", "up_proj", "down_proj"],
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# ... training loop omitted ...

save_with_accelerate(
    accelerator,
    model,
    tokenizer,
    args.output_dir,
    args.use_lora,
)
```
```python
import torch
import transformers
from accelerate import Accelerator
from transformers import PreTrainedModel, PreTrainedTokenizer


def save_with_accelerate(
    accelerator: Accelerator,
    model: torch.nn.Module,
    tokenizer: PreTrainedTokenizer,
    output_dir: str,
    use_lora: bool = False,
) -> None:
    # set the generation config to an empty setting to be safe.
    # we usually do greedy decoding for generation, so this should be okay.
    # otherwise, we get an error thrown at save time.
    model.generation_config = transformers.GenerationConfig(
        temperature=None, top_p=None, eos_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id
    )

    unwrapped_model: PreTrainedModel = accelerator.unwrap_model(model)
    # When doing multi-gpu training, we need to use accelerator.get_state_dict(model) to get the state_dict.
    # Otherwise, sometimes the model will be saved with only part of the parameters.
    # Also, accelerator needs to use the wrapped model to get the state_dict.
    state_dict = accelerator.get_state_dict(model)
    if use_lora:
        # When using lora, the unwrapped model is a PeftModel, which doesn't support the is_main_process
        # and has its own save_pretrained function for only saving lora modules.
        # We have to manually specify the is_main_process outside the save_pretrained function.
        if accelerator.is_main_process:
            unwrapped_model.save_pretrained(output_dir, state_dict=state_dict)
    else:
        # don't use safetensors for saving for now
        unwrapped_model.save_pretrained(
            output_dir,
            is_main_process=accelerator.is_main_process,
            save_function=accelerator.save,
            state_dict=state_dict,
            safe_serialization=False,
        )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)
```
```python
import os

from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM

# merge base model and adapter
peft_config = PeftConfig.from_pretrained(args.lora_model_name_or_path)
base_model = AutoModelForCausalLM.from_pretrained(
    args.base_model_name_or_path if args.base_model_name_or_path else peft_config.base_model_name_or_path,
)
print("Loading the lora model...")
lora_model = PeftModel.from_pretrained(base_model, args.lora_model_name_or_path)
print("Merging the lora modules...")
merged_model = lora_model.merge_and_unload()
output_dir = args.output_dir if args.output_dir else args.lora_model_name_or_path
os.makedirs(output_dir, exist_ok=True)
print(f"Saving merged model to {output_dir}...")
merged_model.save_pretrained(output_dir)
```
The code above is taken from open_instruct/finetune.py and open_instruct/merge_lora.py in that repository. Please tell me if anything is unclear.
When you load the base model before merging, i.e. here:
```python
peft_config = PeftConfig.from_pretrained(args.lora_model_name_or_path)
base_model = AutoModelForCausalLM.from_pretrained(
    args.base_model_name_or_path if args.base_model_name_or_path else peft_config.base_model_name_or_path,
)
```
you are not passing torch_dtype=torch.bfloat16 as you did the first time. Could this be the reason why you get float32 weights? Please always ensure that you load the base model in the same way.
Thanks for helping. The merged model has the right dtype now. I made a stupid mistake :sweat_smile: