Deal with weight tying consistently
Currently, the way PEFT deals with tied embedding and LM head weights is not always clear. In #2803, a new argument, ensure_weight_tying, was introduced to make it easier for users to automatically tie the PEFT weights while keeping backwards compatibility. However, this makes it even more important to clarify what happens in which situation.
The table below shows the intended behavior in different circumstances. Notably, weight tying can affect modules_to_save, target_modules, and trainable_token_indices. The table lists the expected results for all combinations of these factors.
| weights tied | ensure_weight_tying | LoraConfig | result |
|---|---|---|---|
| False | False | modules_to_save=[embed_tokens] / modules_to_save=[lm_head] | ModulesToSaveWrapper on embedding/lm head |
| False | True | modules_to_save=[embed_tokens] / modules_to_save=[lm_head] | warn & ModulesToSaveWrapper on embedding/lm head |
| True | False | modules_to_save=[embed_tokens] / modules_to_save=[lm_head] | ModulesToSaveWrapper on embedding/lm head (BC) |
| True | True | modules_to_save=[embed_tokens] / modules_to_save=[lm_head] | ModulesToSaveWrappers share weights |
| False | False | modules_to_save=[embed_tokens, lm_head] | treat as separate |
| False | True | modules_to_save=[embed_tokens, lm_head] | warn & treat as separate |
| True | False | modules_to_save=[embed_tokens, lm_head] | treat as separate (BC) |
| True | True | modules_to_save=[embed_tokens, lm_head] | ModulesToSaveWrappers share weights |
| False | False | target_modules=[embed_tokens] / target_modules=[lm_head] | LoRA on embedding/lm head |
| False | True | target_modules=[embed_tokens] / target_modules=[lm_head] | *warn & LoRA on embedding/lm head |
| True | False | target_modules=[embed_tokens] / target_modules=[lm_head] | LoRA on embedding/lm head (BC) |
| True | True | target_modules=[embed_tokens] / target_modules=[lm_head] | *LoRA share weights |
| False | False | target_modules=[embed_tokens, lm_head] | treat as separate |
| False | True | target_modules=[embed_tokens, lm_head] | *warn & treat as separate |
| True | False | target_modules=[embed_tokens, lm_head] | treat as separate (BC) |
| True | True | target_modules=[embed_tokens, lm_head] | *LoRA share weights |
| False | False | trainable_token_indices=[1, 2, 3] | trainable tokens on embeddings only |
| False | True | trainable_token_indices=[1, 2, 3] | warn & trainable tokens on embeddings only |
| True | False | trainable_token_indices=[1, 2, 3] | tied trainable tokens |
| True | True | trainable_token_indices=[1, 2, 3] | tied trainable tokens |
| False | False | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]} | treat as separate |
| False | True | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]} | warn & treat as separate |
| True | False | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens |
| True | True | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens |
| False | False | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]} | treat as separate |
| False | True | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]} | warn & treat as separate |
| True | False | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]} | *treat as separate |
| True | True | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]} | *error |
Explanation:
- BC means that we keep this behavior for backwards compatibility, even if it might not be the most intuitive behavior.
- * marks behavior that is not yet implemented as such but should be added.
- For trainable_token_indices, we distinguish between cases where embedding and LM head define the same indices, which would allow weight sharing, and where they define distinct indices, which precludes weight sharing.
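For concreteness, here is a minimal sketch of how the combinations above would look in user code. This is only an illustration: embed_tokens/lm_head are the usual names in decoder-only models but depend on the architecture, and ensure_weight_tying is assumed to be passed via LoraConfig, as introduced in #2803.

```python
from peft import LoraConfig

# Trainable token deltas on a few token indices; with tied base weights and
# ensure_weight_tying=True the table above expects "tied trainable tokens".
config_tokens = LoraConfig(
    trainable_token_indices=[1, 2, 3],
    ensure_weight_tying=True,
)

# Fully training the embedding via ModulesToSaveWrapper; with tied base weights and
# ensure_weight_tying=True the wrappers are expected to share weights.
config_save = LoraConfig(
    target_modules=["q_proj", "v_proj"],  # ordinary LoRA targets, assumed for illustration
    modules_to_save=["embed_tokens"],
    ensure_weight_tying=True,
)
```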
Ping @romitjain
@BenjaminBossan For the following cases:
| weights tied | ensure_weight_tying | LoraConfig | result |
|---|---|---|---|
| True | False | modules_to_save=[embed_tokens, lm_head] | treat as separate (BC) |
| True | False | target_modules=[embed_tokens] / target_modules=[lm_head] | LoRA on embedding/lm head (BC) |
| True | False | target_modules=[embed_tokens, lm_head] | treat as separate (BC) |
Should we also add warnings in these three cases?
> Should we also add warnings in these three cases?
For the target_modules case, we do have a warning, but only when the user wants to merge. One could argue that it would be nicer to have it already at PEFT model initialization time. When discussing this with @githubnemo, we concluded that it could be nice to have, but it is not as important as the other changes, so we omitted it for now.
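For context, a rough sketch of the flow being discussed (model and module names are assumptions; the point is only where the existing warning surfaces, namely at merge time rather than at get_peft_model time):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# opt-125m ties embed_tokens and lm_head
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# LoRA on only one side of the tied pair (one of the target_modules rows above)
config = LoraConfig(target_modules=["lm_head"])

peft_model = get_peft_model(model, config)  # currently: no warning at initialization
merged = peft_model.merge_and_unload()      # currently: the tied-weights warning surfaces here
```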
This looks interesting. I would like to start implementing this, @BenjaminBossan, if no one else has taken it up yet.
@sambhavnoobcoder I have already started work on target_modules. modules_to_save is already done. Can you take up trainable_token_indices?
Yeah sure @romitjain, I'll take that up.
Hi @BenjaminBossan @romitjain,
I have raised a PR for the trainable_token_indices part of this in #2870. I would appreciate it if you could take a look and review it.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Not stale, active PR in review.
I noticed a potentially unexpected behavior when using TrainableTokensConfig with weight tying (weights tied=True, ensure_weight_tying=True).
- When I set `trainable_token_indices = {"lm_head": [1, 2], "embed_tokens": [1, 2]}`, the resulting tensors `lm_head.token_adapter.trainable_tokens_delta.default.data_ptr()` and `embed_tokens.token_adapter.trainable_tokens_delta.default.data_ptr()` do not match.
- But when I set `trainable_token_indices = [1, 2]`, the two modules correctly share the same underlying buffer (their `data_ptr()` is identical), which matches the expected behavior under tied weights.
This feels counter-intuitive, since with tied weights I would expect the behavior of the list version (shared adapter weights) regardless of whether I provide a dict or a list. The dict form seems to break the tying assumption by creating separate deltas.
Not sure if this is intended or an edge case — just wanted to report the discrepancy in case it's a bug.
@Ambitious-idiot Thanks for the report. I wrote a small reproducer:
```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id)

# config = LoraConfig(trainable_token_indices=[1, 2])  # works
config = LoraConfig(trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [1, 2]})  # fails
model = get_peft_model(model, config)

w_head = model.base_model.model.lm_head.token_adapter.trainable_tokens_delta.default
w_embed = model.base_model.model.model.decoder.embed_tokens.token_adapter.trainable_tokens_delta.default
assert w_head.data_ptr() == w_embed.data_ptr()
```
Indeed, this fails, but it will pass once PR #2870 is merged.
@sambhavnoobcoder: On that PR, there is a corresponding test, test_ensure_weight_tying_applied_with_same_indices, but unless I'm missing something, there is no test case for trainable_token_indices=[1, 2]. Could you please parametrize said test to check both? Also, the outcome should be the same for both ensure_weight_tying=True and False, right? So let's parametrize over this parameter too.
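A rough sketch of the requested parametrization, for clarity. The test name and model here are my own; ensure_weight_tying is assumed to be a LoraConfig argument per the table above, and the dict case is only expected to pass once #2870 is merged.

```python
import pytest
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

@pytest.mark.parametrize("ensure_weight_tying", [False, True])
@pytest.mark.parametrize(
    "trainable_token_indices",
    [
        [1, 2],                                       # list form
        {"lm_head": [1, 2], "embed_tokens": [1, 2]},  # dict form with identical indices
    ],
)
def test_trainable_token_deltas_are_shared(trainable_token_indices, ensure_weight_tying):
    # opt-125m ties embed_tokens and lm_head, so the deltas should share storage in all four cases
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    config = LoraConfig(
        trainable_token_indices=trainable_token_indices,
        ensure_weight_tying=ensure_weight_tying,
    )
    model = get_peft_model(model, config)
    w_head = model.base_model.model.lm_head.token_adapter.trainable_tokens_delta.default
    w_embed = model.base_model.model.model.decoder.embed_tokens.token_adapter.trainable_tokens_delta.default
    assert w_head.data_ptr() == w_embed.data_ptr()
```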
Okay @BenjaminBossan, I'll make that change in that PR. Thank you @Ambitious-idiot for pointing this out.
Update: @BenjaminBossan, I added the changes to the test, checked it out, everything works and all tests pass.