Only the classifier head is trained in the tweet sentiment classification LoRA finetuning blog
@mehdiir,
We tried to reproduce your work in our environment and found one odd issue: with your code, `gradient_checkpointing=True` runs much faster than `gradient_checkpointing=False` (2 h vs. 6 h in our CPU environment), which is the opposite of what we expected. So we did some analysis, summarized below:
- In this case, when `gradient_checkpointing=True` is set (and, implicitly, PyTorch's `use_reentrant=True`), the LoRA weights are wrapped inside transformer blocks whose inputs and outputs both have `requires_grad=False`. As a result, none of the transformer blocks executes backpropagation, so only the classifier head is actually trained; the LoRA weights never receive updates and stay at their identity initialization.
- We upgraded transformers to 4.37.2 and added the two lines below in `get_lora_model` to set `use_reentrant` to False; with that, things go back to normal and the LoRA weights are trained.
```diff
 def get_lora_model(model_checkpoints, num_labels=2, rank=4, alpha=16, lora_dropout=0.1, bias='none'):
     ...
+    gradient_checkpointing_kwargs = {"use_reentrant": False}
+    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
     model = get_peft_model(model, peft_config)
     print(model.print_trainable_parameters())
     return model
```
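To see the difference yourself, here is a minimal verification sketch (our assumptions: `get_lora_model` is the function above and `model_checkpoints` is the same base checkpoint the blog uses): run one forward/backward pass on a dummy tweet and check which trainable parameters actually receive gradients.

```python
import torch
from transformers import AutoTokenizer

# Verification sketch (assumed names: `model_checkpoints` is the blog's base
# checkpoint, `get_lora_model` is the function shown above).
model = get_lora_model(model_checkpoints)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoints)
model.train()

# One forward/backward pass on a dummy example.
batch = tokenizer(["what a great day"], return_tensors="pt")
loss = model(**batch, labels=torch.tensor([1])).loss
loss.backward()

# With the reentrant-checkpointing issue, every LoRA parameter reports grad=None
# and only the classifier head receives gradients; with use_reentrant=False the
# LoRA parameters get populated gradients as well.
for name, param in model.named_parameters():
    if param.requires_grad:
        print(f"{name}: grad is None = {param.grad is None}")
```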
FYI, in case other people run into the same issue.