`load_adapter` seems to require the base model to be identical
System Info
peft: 0.11.1, accelerate: 0.30.1, transformers: 4.41.2, trl: 0.9.4, python: 3.10.14
Who can help?
No response
Information
- [ ] The official example scripts
- [ ] My own modified scripts
Tasks
- [ ] An officially supported task in the `examples` folder
- [ ] My own task or dataset (give details below)
Reproduction
model, tokenizer = setup_chat_format(model, tokenizer)  # this line adds additional token embeddings
model.load_adapter(peft_model_id, adapter_name="default")  # peft_model_id points to the same model before its embedding layer was resized
Expected behavior
https://github.com/huggingface/peft/blob/52684952136b3dcc1120834f8af2d8f5aaa1c16a/src/peft/peft_model.py#L1028 As the title suggests, I tried to use model.load_adapter(peft_model_id, adapter_name="default") to load an adapter from a PEFT model whose token embedding has different dimensions, and got the following message:
<_IncompatibleKeys, len() = 2>
_IncompatibleKeys(missing_keys=['base_model.model.model.embed_tokens.weight', 'base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.0.self_attn.k_proj.weight', 'base_model.model.model.layers.0.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.0.self_attn.o_proj.weight', 'base_model.model.model.layers.0.mlp.gate_proj.weight', 'base_model.model.model.layers.0.mlp.up_proj.weight', 'base_model.model.model.layers.0.mlp.down_proj.weight', 'base_model.model.model.layers.0.input_layernorm.weight', 'base_model.model.model.layers.0.post_attention_layernorm.weight', 'base_model.model.model.layers.1.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.1.self_attn.k_proj.weight', 'base_model.model.model.layers.1.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.1.self_attn.o_proj.weight', 'base_model.model.model.layers.1.mlp.gate_proj.weight', 'base_model.model.model.layers.1.mlp.up_proj.weight', 'base_model.model.model.layers.1.mlp.down_proj.weight', 'base_model.model.model.layers.1.input_layernorm.weight', 'base_model.model.model.layers.1.post_attention_layernorm.weight', 'base_model.model.model.layers.2.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.2.self_attn.k_proj.weight', 'base_model.model.model.layers.2.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.2.self_attn.o_proj.weight', 'base_model.model.model.layers.2.mlp.gate_proj.weight', 'base_model.model.model.layers.2.mlp.up_proj.weight', 'base_model.model.model.layers.2.mlp.down_proj.weight', 'base_model.model.model.layers.2.input_layernorm.weight', 'base_model.model.model.layers.2.post_attention_layernorm.weight', 'base_model.model.model.layers.3.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.3.self_attn.k_proj.weight', 'base_model.model.model.layers.3.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.3.self_attn.o_proj.weight', 'base_model.model.model.layers.3.mlp.gate_proj.weight', 'base_model.model.model.layers.3.mlp.up_proj.weight', 'base_model.model.model.layers.3.mlp.down_proj.weight', 'base_model.model.model.layers.3.input_layernorm.weight', 'base_model.model.model.layers.3.post_attention_layernorm.weight', 'base_model.model.model.layers.4.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.4.self_attn.k_proj.weight', 'base_model.model.model.layers.4.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.4.self_attn.o_proj.weight', 'base_model.model.model.layers.4.mlp.gate_proj.weight', 'base_model.model.model.layers.4.mlp.up_proj.weight', 'base_model.model.model.layers.4.mlp.down_proj.weight', 'base_model.model.model.layers.4.input_layernorm.weight', 'base_model.model.model.layers.4.post_attention_layernorm.weight', 'base_model.model.model.layers.5.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.5.self_attn.k_proj.weight', 'base_model.model.model.layers.5.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.5.self_attn.o_proj.weight', 'base_model.model.model.layers.5.mlp.gate_proj.weight', 'base_model.model.model.layers.5.mlp.up_proj.weight', 'base_model.model.model.layers.5.mlp.down_proj.weight', 'base_model.model.model.layers.5.input_layernorm.weight', 'base_model.model.model.layers.5.post_attention_layernorm.weight', 'base_model.model.model.layers.6.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.6.self_attn.k_proj.weight', 
'base_model.model.model.layers.6.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.6.self_attn.o_proj.weight', 'base_model.model.model.layers.6.mlp.gate_proj.weight', 'base_model.model.model.layers.6.mlp.up_proj.weight', 'base_model.model.model.layers.6.mlp.down_proj.weight', 'base_model.model.model.layers.6.input_layernorm.weight', 'base_model.model.model.layers.6.post_attention_layernorm.weight', 'base_model.model.model.layers.7.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.7.self_attn.k_proj.weight', 'base_model.model.model.layers.7.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.7.self_attn.o_proj.weight', 'base_model.model.model.layers.7.mlp.gate_proj.weight', 'base_model.model.model.layers.7.mlp.up_proj.weight', 'base_model.model.model.layers.7.mlp.down_proj.weight', 'base_model.model.model.layers.7.input_layernorm.weight', 'base_model.model.model.layers.7.post_attention_layernorm.weight', 'base_model.model.model.layers.8.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.8.self_attn.k_proj.weight', 'base_model.model.model.layers.8.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.8.self_attn.o_proj.weight', 'base_model.model.model.layers.8.mlp.gate_proj.weight', 'base_model.model.model.layers.8.mlp.up_proj.weight', 'base_model.model.model.layers.8.mlp.down_proj.weight', 'base_model.model.model.layers.8.input_layernorm.weight', 'base_model.model.model.layers.8.post_attention_layernorm.weight', 'base_model.model.model.layers.9.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.9.self_attn.k_proj.weight', 'base_model.model.model.layers.9.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.9.self_attn.o_proj.weight', 'base_model.model.model.layers.9.mlp.gate_proj.weight', 'base_model.model.model.layers.9.mlp.up_proj.weight', 'base_model.model.model.layers.9.mlp.down_proj.weight', 'base_model.model.model.layers.9.input_layernorm.weight', 'base_model.model.model.layers.9.post_attention_layernorm.weight', 'base_model.model.model.layers.10.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.10.self_attn.k_proj.weight', 'base_model.model.model.layers.10.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.10.self_attn.o_proj.weight', 'base_model.model.model.layers.10.mlp.gate_proj.weight', 'base_model.model.model.layers.10.mlp.up_proj.weight', 'base_model.model.model.layers.10.mlp.down_proj.weight', 'base_model.model.model.layers.10.input_layernorm.weight', 'base_model.model.model.layers.10.post_attention_layernorm.weight', 'base_model.model.model.layers.11.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.11.self_attn.k_proj.weight', 'base_model.model.model.layers.11.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.11.self_attn.o_proj.weight', 'base_model.model.model.layers.11.mlp.gate_proj.weight', 'base_model.model.model.layers.11.mlp.up_proj.weight', 'base_model.model.model.layers.11.mlp.down_proj.weight', 'base_model.model.model.layers.11.input_layernorm.weight', 'base_model.model.model.layers.11.post_attention_layernorm.weight', 'base_model.model.model.layers.12.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.12.self_attn.k_proj.weight', 'base_model.model.model.layers.12.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.12.self_attn.o_proj.weight', 'base_model.model.model.layers.12.mlp.gate_proj.weight', 
'base_model.model.model.layers.12.mlp.up_proj.weight', 'base_model.model.model.layers.12.mlp.down_proj.weight', 'base_model.model.model.layers.12.input_layernorm.weight', 'base_model.model.model.layers.12.post_attention_layernorm.weight', 'base_model.model.model.layers.13.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.13.self_attn.k_proj.weight', 'base_model.model.model.layers.13.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.13.self_attn.o_proj.weight', 'base_model.model.model.layers.13.mlp.gate_proj.weight', 'base_model.model.model.layers.13.mlp.up_proj.weight', 'base_model.model.model.layers.13.mlp.down_proj.weight', 'base_model.model.model.layers.13.input_layernorm.weight', 'base_model.model.model.layers.13.post_attention_layernorm.weight', 'base_model.model.model.layers.14.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.14.self_attn.k_proj.weight', 'base_model.model.model.layers.14.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.14.self_attn.o_proj.weight', 'base_model.model.model.layers.14.mlp.gate_proj.weight', 'base_model.model.model.layers.14.mlp.up_proj.weight', 'base_model.model.model.layers.14.mlp.down_proj.weight', 'base_model.model.model.layers.14.input_layernorm.weight', 'base_model.model.model.layers.14.post_attention_layernorm.weight', 'base_model.model.model.layers.15.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.15.self_attn.k_proj.weight', 'base_model.model.model.layers.15.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.15.self_attn.o_proj.weight', 'base_model.model.model.layers.15.mlp.gate_proj.weight', 'base_model.model.model.layers.15.mlp.up_proj.weight', 'base_model.model.model.layers.15.mlp.down_proj.weight', 'base_model.model.model.layers.15.input_layernorm.weight', 'base_model.model.model.layers.15.post_attention_layernorm.weight', 'base_model.model.model.layers.16.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.16.self_attn.k_proj.weight', 'base_model.model.model.layers.16.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.16.self_attn.o_proj.weight', 'base_model.model.model.layers.16.mlp.gate_proj.weight', 'base_model.model.model.layers.16.mlp.up_proj.weight', 'base_model.model.model.layers.16.mlp.down_proj.weight', 'base_model.model.model.layers.16.input_layernorm.weight', 'base_model.model.model.layers.16.post_attention_layernorm.weight', 'base_model.model.model.layers.17.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.17.self_attn.k_proj.weight', 'base_model.model.model.layers.17.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.17.self_attn.o_proj.weight', 'base_model.model.model.layers.17.mlp.gate_proj.weight', 'base_model.model.model.layers.17.mlp.up_proj.weight', 'base_model.model.model.layers.17.mlp.down_proj.weight', 'base_model.model.model.layers.17.input_layernorm.weight', 'base_model.model.model.layers.17.post_attention_layernorm.weight', 'base_model.model.model.layers.18.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.18.self_attn.k_proj.weight', 'base_model.model.model.layers.18.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.18.self_attn.o_proj.weight', 'base_model.model.model.layers.18.mlp.gate_proj.weight', 'base_model.model.model.layers.18.mlp.up_proj.weight', 'base_model.model.model.layers.18.mlp.down_proj.weight', 'base_model.model.model.layers.18.input_layernorm.weight', 
'base_model.model.model.layers.18.post_attention_layernorm.weight', 'base_model.model.model.layers.19.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.19.self_attn.k_proj.weight', 'base_model.model.model.layers.19.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.19.self_attn.o_proj.weight', 'base_model.model.model.layers.19.mlp.gate_proj.weight', 'base_model.model.model.layers.19.mlp.up_proj.weight', 'base_model.model.model.layers.19.mlp.down_proj.weight', 'base_model.model.model.layers.19.input_layernorm.weight', 'base_model.model.model.layers.19.post_attention_layernorm.weight', 'base_model.model.model.layers.20.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.20.self_attn.k_proj.weight', 'base_model.model.model.layers.20.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.20.self_attn.o_proj.weight', 'base_model.model.model.layers.20.mlp.gate_proj.weight', 'base_model.model.model.layers.20.mlp.up_proj.weight', 'base_model.model.model.layers.20.mlp.down_proj.weight', 'base_model.model.model.layers.20.input_layernorm.weight', 'base_model.model.model.layers.20.post_attention_layernorm.weight', 'base_model.model.model.layers.21.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.21.self_attn.k_proj.weight', 'base_model.model.model.layers.21.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.21.self_attn.o_proj.weight', 'base_model.model.model.layers.21.mlp.gate_proj.weight', 'base_model.model.model.layers.21.mlp.up_proj.weight', 'base_model.model.model.layers.21.mlp.down_proj.weight', 'base_model.model.model.layers.21.input_layernorm.weight', 'base_model.model.model.layers.21.post_attention_layernorm.weight', 'base_model.model.model.layers.22.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.22.self_attn.k_proj.weight', 'base_model.model.model.layers.22.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.22.self_attn.o_proj.weight', 'base_model.model.model.layers.22.mlp.gate_proj.weight', 'base_model.model.model.layers.22.mlp.up_proj.weight', 'base_model.model.model.layers.22.mlp.down_proj.weight', 'base_model.model.model.layers.22.input_layernorm.weight', 'base_model.model.model.layers.22.post_attention_layernorm.weight', 'base_model.model.model.layers.23.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.23.self_attn.k_proj.weight', 'base_model.model.model.layers.23.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.23.self_attn.o_proj.weight', 'base_model.model.model.layers.23.mlp.gate_proj.weight', 'base_model.model.model.layers.23.mlp.up_proj.weight', 'base_model.model.model.layers.23.mlp.down_proj.weight', 'base_model.model.model.layers.23.input_layernorm.weight', 'base_model.model.model.layers.23.post_attention_layernorm.weight', 'base_model.model.model.layers.24.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.24.self_attn.k_proj.weight', 'base_model.model.model.layers.24.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.24.self_attn.o_proj.weight', 'base_model.model.model.layers.24.mlp.gate_proj.weight', 'base_model.model.model.layers.24.mlp.up_proj.weight', 'base_model.model.model.layers.24.mlp.down_proj.weight', 'base_model.model.model.layers.24.input_layernorm.weight', 'base_model.model.model.layers.24.post_attention_layernorm.weight', 'base_model.model.model.layers.25.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.25.self_attn.k_proj.weight', 
'base_model.model.model.layers.25.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.25.self_attn.o_proj.weight', 'base_model.model.model.layers.25.mlp.gate_proj.weight', 'base_model.model.model.layers.25.mlp.up_proj.weight', 'base_model.model.model.layers.25.mlp.down_proj.weight', 'base_model.model.model.layers.25.input_layernorm.weight', 'base_model.model.model.layers.25.post_attention_layernorm.weight', 'base_model.model.model.layers.26.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.26.self_attn.k_proj.weight', 'base_model.model.model.layers.26.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.26.self_attn.o_proj.weight', 'base_model.model.model.layers.26.mlp.gate_proj.weight', 'base_model.model.model.layers.26.mlp.up_proj.weight', 'base_model.model.model.layers.26.mlp.down_proj.weight', 'base_model.model.model.layers.26.input_layernorm.weight', 'base_model.model.model.layers.26.post_attention_layernorm.weight', 'base_model.model.model.layers.27.self_attn.q_proj.base_layer.weight', 'base_model.model.model.layers.27.self_attn.k_proj.weight', 'base_model.model.model.layers.27.self_attn.v_proj.base_layer.weight', 'base_model.model.model.layers.27.self_attn.o_proj.weight', 'base_model.model.model.layers.27.mlp.gate_proj.weight', 'base_model.model.model.layers.27.mlp.up_proj.weight', 'base_model.model.model.layers.27.mlp.down_proj.weight', 'base_model.model.model.layers.27.input_layernorm.weight', 'base_model.model.model.layers.27.post_attention_layernorm.weight', 'base_model.model.model.norm.weight', 'base_model.model.lm_head.weight'], unexpected_keys=[])
Is this intended? I assumed that model.load_adapter(peft_model_id, adapter_name="default") only loads the LoRA parameters, so why does it care whether the base models are identical or not?
TLDR: You can just ignore this.
Long story: This is a bit of an older annoyance that we should eventually fix.
Under the hood, PEFT uses the load_state_dict method from PyTorch, which automatically creates this output. Since for a PEFT model, we only load a subset of parameters, PyTorch thinks that the state_dict could be missing some items, so it returns the names of missing keys. In PEFT, we don't do anything further except for returning this result to the user.
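To illustrate, here is a minimal, self-contained sketch in plain PyTorch (not PEFT code) of how load_state_dict(strict=False) produces such a result when only a subset of the weights is provided:

```python
import torch
import torch.nn as nn

# Toy model: pretend the second Linear is the "adapter" part we want to load.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# A state_dict that only covers a subset of the parameters.
partial = {"1.weight": torch.zeros(2, 4), "1.bias": torch.zeros(2)}

# strict=False loads what it can and reports the rest as "missing",
# even though we never intended to load those weights.
result = model.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['0.weight', '0.bias']
print(result.unexpected_keys)  # []
```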
We could instead just ignore this and not show anything to the user to avoid potential confusion. I think this would be reasonable, but there is a small chance that something really went wrong, in which case users can debug more easily if they can see the list of missing/unexpected keys.
When calling from_pretrained, this list is also created, but since we return the model from that method, the user never sees the _IncompatibleKeys, which is inconsistent.
Probably the correct way to handle this would be to filter the missing keys to exclude everything not related to the adapter that is currently being loaded. If, after this, there are still some missing keys, we could warn the user about this. This should be consistently done for load_adapter and from_pretrained.
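Roughly, the filtering could look like this (filter_and_warn and its arguments are just a sketch, not existing PEFT code):

```python
import warnings

def filter_and_warn(missing_keys, prefix, adapter_name):
    # Keep only keys that belong to the adapter being loaded; everything
    # else is a base-model weight that load_adapter never meant to load.
    adapter_keys = [k for k in missing_keys if prefix in k and adapter_name in k]
    if adapter_keys:
        warnings.warn(f"Found missing adapter keys while loading: {adapter_keys}")
    return adapter_keys
```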
I put this issue on the PEFT backlog. If anyone wants to contribute this enhancement, feel free to open a PR.
@BenjaminBossan I see. Appreciate the in-depth explanation!
Hi @BenjaminBossan, if this feature is still a priority, I would like to work on it.
@yaswanth19 Thanks for the offer, this is still open, so feel free to give it a go. Don't hesitate to ask if you have questions or to create an early draft PR for quick feedback.
Yeah sure Thanks!!
Hi @BenjaminBossan, Here is the approach I am thinking:
Modify the load_adapter function to filter out non-adapter-related keys. After loading the state dict, capture the missing_keys and unexpected_keys, filter them down to adapter-relevant keys (using the target modules to identify them), and raise a warning only if relevant adapter keys are missing. If no adapter-related keys are affected, suppress the warning.
Does this approach sound right, or is there anything I am missing?
This sounds good, thanks. If you have some code already, don't hesitate to create an early draft PR, even without tests etc., so that I can give quick feedback.
@BenjaminBossan
I tried out a few ideas:
1.) Use the check_target_module_exists function on the missing keys, since there are several different ways of specifying which modules the method applies to.
But I got stuck: a key can be named encoder.layer0.self.query.lora_A.weight, which doesn't satisfy the "ends with" condition in check_target_module_exists when target_modules=['query'] is given. Adding extra rules for each adapter type would make this more error-prone (see the small illustration after this list).
2.) Use the tuner name prefix to check whether any adapter-related key is present in the missing keys.
But the tuner name is not always added as a prefix to the layer name; LoHa, for example.
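To illustrate the issue with idea 1, here is a self-contained sketch that mimics (rather than calls) the "ends with" rule:

```python
# Simplified version of the matching rule; not PEFT's actual implementation.
def matches_target(key: str, target_module: str) -> bool:
    return key == target_module or key.endswith("." + target_module)

# The module key matches, but the injected parameter name does not:
print(matches_target("encoder.layer0.self.query", "query"))                # True
print(matches_target("encoder.layer0.self.query.lora_A.weight", "query"))  # False
```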
I'm not sure how to proceed. The other option would be to mimic the adapter-injection process in reverse to find the relevant adapter keys, but I think that would be too complex. Is there anything we can do just from the result of the missing and unexpected keys?
Thanks for the update @yaswanth19.
Indeed, I think approach 1 will be difficult to make work. Regarding approach 2, could you give an example of adapter-related keys that are missing the prefix? E.g. for LoHa, the prefix is defined as hada_ here, and AFAICT all LoHa parameters use that prefix.
I think this approach would still not be guaranteed to work perfectly, since 1) the existence of a prefix is not a guarantee that the parameter belongs to the adapter and 2) not all PEFT methods use a prefix at all (LayerNorm tuning does not add extra parameters, for instance). However, we don't need to find a perfect solution. The current situation is imperfect, so if we can make it better, even if still imperfect, it would be an improvement.
@BenjaminBossan
Right, one thing we can do is create a mapping from adapter type to prefix and check for the prefix in the missing keys.
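Something along these lines (the mapping is illustrative and incomplete, not an actual PEFT constant):

```python
# Hypothetical adapter-type-to-prefix mapping; entries are examples only.
PEFT_TYPE_TO_PREFIX = {
    "LORA": "lora_",
    "LOHA": "hada_",
}

def is_adapter_key(key: str, peft_type: str) -> bool:
    # A key counts as adapter-related if it contains the method's prefix.
    return PEFT_TYPE_TO_PREFIX[peft_type] in key
```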
One case I observed with the VeRA tuner on BERT: when saving the adapter weights, only vera_lambda_b and vera_lambda_d were stored, even though the model had vera_A and vera_B layers. I don't know why that is, but it could cause false positives with prefix matching.
Good point about VeRA. The idea behind VeRA is that the vera_A and vera_B layers are shared among all layers, so they don't need to be stored for each layer separately (optionally, they might not be stored at all and instead be recreated from a random seed).
I think there are enough edge cases that we should not raise an error or warn when we find missing keys (i.e. stick with the status quo). However, it would still be helpful to filter out false positives from the missing_keys as best as possible to avoid misleading situations like the one reported by OP, which should still cover most cases.
We can close this issue as the PR is merged.