FIX Store batch norm buffers in PEFT checkpoint
Fixes #1732
After loading a model that was trained with PEFT on a base model that contains some kind of batch norm layer, the loaded model should produce the same output as it did at the end of training. Right now, this does not happen.
The reason is that during training, the buffers for the running mean etc. are updated, but they are not saved when calling `save_pretrained` on the `PeftModel` instance. Normally in PEFT, we assume that during training, the base model parameters are kept constant; this is not the case with batch norm. We only save the PEFT parameters and assume that when the user loads the base model, all parameters are restored exactly. As a result, the information in the buffers is lost completely.
This PR fixes this issue by saving the buffers of the batch norm layers. They are identified by checking for the presence of the `track_running_stats` attribute.
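For illustration, here is a minimal sketch of that idea (the helper name and the toy model are made up for this example, this is not the actual diff): walk the model and collect the buffers of every module that exposes `track_running_stats`, i.e. batch-norm-style layers whose running stats change during training.

```python
import torch
from torch import nn


def collect_batch_norm_buffers(model: nn.Module) -> dict[str, torch.Tensor]:
    """Collect the buffers of all modules that expose track_running_stats."""
    buffers = {}
    for module_name, module in model.named_modules():
        if getattr(module, "track_running_stats", False):
            for buffer_name, buffer in module.named_buffers(recurse=False):
                buffers[f"{module_name}.{buffer_name}"] = buffer
    return buffers


# A BatchNorm2d layer contributes running_mean, running_var and num_batches_tracked:
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
print(list(collect_batch_norm_buffers(model)))
# ['1.running_mean', '1.running_var', '1.num_batches_tracked']
```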
Note: One test for BOFT is currently failing, see the comment in the test file.
@BenjaminBossan when merging the weights... I suppose currently it will work for `nn.Params` but now it also needs to do the same for the buffers... how is that handled?
> when merging the weights... I suppose currently it will work for `nn.Params` but now it also needs to do the same for the buffers... how is that handled?
Hmm, not sure if I understand. If I train a LoRA adapter and then merge it, its weights will be fused with the base model weights. When we load the LoRA adapter, the base model's running stats are replaced by whatever is saved in the LoRA checkpoint. As the running stats buffers are part of the base model and no LoRA is applied to them, they are not further affected by merging.
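To make that concrete, here is a toy sketch (the model and module names are made up, not taken from the PR's tests): the running stats buffers sit on the base model's batch norm layer, LoRA never wraps that layer, and `merge_and_unload` only fuses the LoRA weights into the targeted conv, leaving the buffers untouched.

```python
import torch
from torch import nn
from peft import LoraConfig, get_peft_model

base = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
peft_model = get_peft_model(base, LoraConfig(target_modules=["0"]))

# A forward pass in train mode updates the running stats of the batch norm layer.
peft_model.train()
peft_model(torch.randn(4, 3, 16, 16))

stats_before = {k: v.clone() for k, v in peft_model.get_base_model()[1].named_buffers()}
merged = peft_model.merge_and_unload()  # fuses the LoRA weights into the conv only

# The running stats in the merged model are exactly the ones from before merging.
for name, buffer in merged[1].named_buffers():
    assert torch.equal(buffer, stats_before[name])
```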
Based on your comment, I could think of another problematic case though: A user adds 2 adapters, first adapter `A`, then `B`. Let's call the running stats buffers `rs`. They train `A` first, updating the running stats to `rs_A`. Then they switch to `B` and train, which updates the running stats further to `rs_A_B`. When they now save the adapter, we will store `rs_A_B` for both adapters, when in reality we want `rs_A` and `rs_B`.
There are probably more edge cases that result from the fact that we kinda assume that only the PEFT parameters ever change, whereas the base model parameters are fixed. I think for this particular one, we can accept this failure case for now, as the scenario I describe should be very rare (users would typically train `A` and `B` separately, not in the same process).
Oh, now I wonder if there isn't a different solution: ask users to add the batch norm layers to `modules_to_save`. I haven't thought this through completely, but this may solve all the problems we discussed with batch norm. The disadvantage is that users need to explicitly set `modules_to_save`.
~~yes, that was my initial solution, I believe...~~ sorry, I was thinking of something else
Just tested the `modules_to_save` solution: when using the config `LoraConfig(target_modules=["convolution"], modules_to_save=["classifier", "normalization"])`, the test passes even without changing the code to explicitly add buffers to the `state_dict`.
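For reference, a self-contained sketch of this workaround on a toy model (the `conv`/`norm` module names below are made up; in the PR's test the modules are called `convolution`, `classifier` and `normalization`): wrapping the batch norm layer via `modules_to_save` makes PEFT store a trainable copy of it, buffers included, so the running stats survive a save/load round trip.

```python
import copy
import tempfile

import torch
from torch import nn
from peft import LoraConfig, PeftModel, get_peft_model


class TinyConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.norm = nn.BatchNorm2d(8)

    def forward(self, x):
        return self.norm(self.conv(x))


base = TinyConvNet()
base_copy = copy.deepcopy(base)  # stands in for re-loading the same base model later

config = LoraConfig(target_modules=["conv"], modules_to_save=["norm"])
peft_model = get_peft_model(base, config)

# Forward passes in train mode update the running stats of the batch norm copy
# created by modules_to_save.
peft_model.train()
for _ in range(5):
    peft_model(torch.randn(4, 3, 16, 16))

tmp_dir = tempfile.mkdtemp()
peft_model.save_pretrained(tmp_dir)

# Re-load the adapter on a fresh copy of the base model: the running stats are
# restored from the checkpoint, so the outputs match.
reloaded = PeftModel.from_pretrained(base_copy, tmp_dir)
peft_model.eval()
reloaded.eval()
x = torch.randn(2, 3, 16, 16)
assert torch.allclose(peft_model(x), reloaded(x))
```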
@pacman100 I removed the new functionality and instead changed the test to add batch norm layers to `modules_to_save`, and the tests pass. I also added a troubleshooting section about this.