
Deal with weight tying consistently

Open BenjaminBossan opened this issue 2 months ago • 6 comments

Currently, the way PEFT deals with tied embedding and LM head weights is not always clear. In #2803, a new argument, ensure_weight_tying, was introduced to make it easier for users to automatically tie the PEFT weights while keeping backwards compatibility. However, this makes it even more important to clarify what happens in which situation.

The table below shows the intended behavior in different circumstances. Notably, weight tying can affect modules_to_save, target_modules, and trainable_token_indices. The table lists the expected results for all combinations of these factors (a small sketch after the explanation list below shows how to check which case applies).

| weights tied | ensure_weight_tying | LoraConfig | result |
|---|---|---|---|
| False | False | modules_to_save=[embed_tokens] / modules_to_save=[lm_head] | ModulesToSaveWrapper on embedding/lm head |
| False | True | modules_to_save=[embed_tokens] / modules_to_save=[lm_head] | warn & ModulesToSaveWrapper on embedding/lm head |
| True | False | modules_to_save=[embed_tokens] / modules_to_save=[lm_head] | ModulesToSaveWrapper on embedding/lm head (BC) |
| True | True | modules_to_save=[embed_tokens] / modules_to_save=[lm_head] | ModulesToSaveWrappers share weights |
| False | False | modules_to_save=[embed_tokens, lm_head] | treat as separate |
| False | True | modules_to_save=[embed_tokens, lm_head] | warn & treat as separate |
| True | False | modules_to_save=[embed_tokens, lm_head] | treat as separate (BC) |
| True | True | modules_to_save=[embed_tokens, lm_head] | ModulesToSaveWrappers share weights |
| False | False | target_modules=[embed_tokens] / target_modules=[lm_head] | LoRA on embedding/lm head |
| False | True | target_modules=[embed_tokens] / target_modules=[lm_head] | *warn & LoRA on embedding/lm head |
| True | False | target_modules=[embed_tokens] / target_modules=[lm_head] | LoRA on embedding/lm head (BC) |
| True | True | target_modules=[embed_tokens] / target_modules=[lm_head] | *LoRA share weights |
| False | False | target_modules=[embed_tokens, lm_head] | treat as separate |
| False | True | target_modules=[embed_tokens, lm_head] | *warn & treat as separate |
| True | False | target_modules=[embed_tokens, lm_head] | treat as separate (BC) |
| True | True | target_modules=[embed_tokens, lm_head] | *LoRA share weights |
| False | False | trainable_token_indices=[1, 2, 3] | trainable tokens on embeddings only |
| False | True | trainable_token_indices=[1, 2, 3] | warn & trainable tokens on embeddings only |
| True | False | trainable_token_indices=[1, 2, 3] | tied trainable tokens |
| True | True | trainable_token_indices=[1, 2, 3] | tied trainable tokens |
| False | False | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]} | treat as separate |
| False | True | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]} | warn & treat as separate |
| True | False | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens |
| True | True | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [1,2]} | tied trainable tokens |
| False | False | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]} | treat as separate |
| False | True | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]} | warn & treat as separate |
| True | False | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]} | *treat as separate |
| True | True | trainable_token_indices={"lm_head": [1,2], "embed_tokens": [3,4]} | *error |

Explanation:

  • BC means that we keep this behavior for backwards compatibility, even if it might not be the most intuitive behavior.
  • * marks behavior that is not yet implemented as such but should be added.
  • For trainable_token_indices, we distinguish between cases where embedding and LM head define the same indices, which would allow weight sharing, and where they define distinct indices, which precludes weight sharing.
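
As a rough illustration of the "share weights" rows, the snippet below checks which case applies by comparing the underlying storage of the two adapted modules. It is a minimal sketch only: it assumes that the ensure_weight_tying argument from #2803 is passed on the LoraConfig and that the attribute path of the TrainableTokens adapter matches the current layer layout; facebook/opt-125m is used because it ties its embedding and LM head.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# facebook/opt-125m ties embed_tokens and lm_head, so the "weights tied" column is True
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")

config = LoraConfig(
    trainable_token_indices=[1, 2, 3],
    ensure_weight_tying=True,  # assumption: the #2803 argument lives on LoraConfig
)
model = get_peft_model(model, config)

# per the "tied trainable tokens" row, both deltas should point at the same storage
delta_embed = model.base_model.model.model.decoder.embed_tokens.token_adapter.trainable_tokens_delta.default
delta_head = model.base_model.model.lm_head.token_adapter.trainable_tokens_delta.default
print(delta_embed.data_ptr() == delta_head.data_ptr())  # expected: True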

Ping @romitjain

BenjaminBossan avatar Oct 23 '25 13:10 BenjaminBossan

@BenjaminBossan for the following case

| weights tied | ensure_weight_tying | LoraConfig | result |
|---|---|---|---|
| True | False | modules_to_save=[embed_tokens, lm_head] | treat as separate (BC) |
| True | False | target_modules=[embed_tokens] / target_modules=[lm_head] | LoRA on embedding/lm head (BC) |
| True | False | target_modules=[embed_tokens, lm_head] | treat as separate (BC) |

Should we also add warnings in these three cases?

romitjain avatar Oct 23 '25 13:10 romitjain

> Should we also add warnings in these three cases?

For the target_modules case, we do have a warning, but only when the user wants to merge. One could argue that it would be nicer to have it already at PEFT model initialization time. When discussing this with @githubnemo, we concluded that it would be nice to have, but not as important as the other changes, so we omitted it for now.
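
For context on the "weights tied" column in these BC rows, this is one way to check whether a given base model actually ties its embedding and LM head before deciding on a config. A minimal sketch using standard transformers accessors, with facebook/opt-125m as an example of a model that does tie them:

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
# config flag that requests tying between the input embeddings and the LM head
print(model.config.tie_word_embeddings)
# whether the two modules really share one parameter tensor
print(model.get_input_embeddings().weight is model.get_output_embeddings().weight)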

BenjaminBossan avatar Oct 23 '25 14:10 BenjaminBossan

This looks interesting. I would like to start implementing this, @BenjaminBossan, if no one else has taken it up yet.

sambhavnoobcoder avatar Oct 24 '25 12:10 sambhavnoobcoder

@sambhavnoobcoder I have already started working on target_modules, and modules_to_save is already done. Can you take up trainable_token_indices?

romitjain avatar Oct 24 '25 12:10 romitjain

Yeah, sure @romitjain, I'll take that up.

sambhavnoobcoder avatar Oct 24 '25 13:10 sambhavnoobcoder

Hi @BenjaminBossan @romitjain, I have raised a PR for the trainable_token_indices part of this in #2870. I would appreciate it if you could take a look and review it.

sambhavnoobcoder avatar Oct 26 '25 21:10 sambhavnoobcoder

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions[bot] avatar Nov 22 '25 15:11 github-actions[bot]

Not stale, active PR in review.

sambhavnoobcoder avatar Nov 22 '25 15:11 sambhavnoobcoder

I noticed a potentially unexpected behavior when using TrainableTokensConfig with weight tying (weights tied=True, ensure_weight_tying=True).

  • When I set

    trainable_token_indices = {"lm_head": [1, 2], "embed_tokens": [1, 2]}
    

    the resulting tensors lm_head.token_adapter.trainable_tokens_delta.default.data_ptr() and embed_tokens.token_adapter.trainable_tokens_delta.default.data_ptr() do not match.

  • But when I set

    trainable_token_indices = [1, 2]
    

    the two modules correctly share the same underlying buffer (their data_ptr() is identical), which matches the expected behavior under tied weights.

This feels counter-intuitive, since with tied weights I would expect the behavior of the list version (shared adapter weights) regardless of whether I provide a dict or a list. The dict form seems to break the tying assumption by creating separate deltas.

Not sure if this is intended or an edge case — just wanted to report the discrepancy in case it's a bug.

Ambitious-idiot avatar Nov 27 '25 07:11 Ambitious-idiot

@Ambitious-idiot Thanks for the report. I wrote a small reproducer:

from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

model_id = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_id)
# config = LoraConfig(trainable_token_indices=[1, 2])  # works
config = LoraConfig(trainable_token_indices={"lm_head": [1, 2], "embed_tokens": [1, 2]})  # fails
model = get_peft_model(model, config)
w_head = model.base_model.model.lm_head.token_adapter.trainable_tokens_delta.default
w_embed = model.base_model.model.model.decoder.embed_tokens.token_adapter.trainable_tokens_delta.default
assert w_head.data_ptr() == w_embed.data_ptr()

Indeed, this fails, but it will pass once PR #2870 is merged.

@sambhavnoobcoder: On that PR, there is a corresponding test, test_ensure_weight_tying_applied_with_same_indices, but unless I'm missing something, there is no test case for trainable_token_indices=[1, 2]. Could you please parametrize said test to check both? Also, the outcome should be the same for both ensure_weight_tying=True and False, right? So let's parametrize over this parameter too.
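
For reference, a rough sketch of what the requested parametrization could look like. The test name comes from the comment above, the attribute paths mirror the reproducer, and the placement of ensure_weight_tying on LoraConfig is an assumption, so treat this as a sketch rather than the actual test from PR #2870:

import pytest
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

@pytest.mark.parametrize("ensure_weight_tying", [False, True])
@pytest.mark.parametrize(
    "trainable_token_indices",
    [[1, 2], {"lm_head": [1, 2], "embed_tokens": [1, 2]}],
)
def test_ensure_weight_tying_applied_with_same_indices(trainable_token_indices, ensure_weight_tying):
    # with a tied base model, both the list and the dict form should end up sharing
    # the trainable token deltas, regardless of ensure_weight_tying
    model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
    config = LoraConfig(
        trainable_token_indices=trainable_token_indices,
        ensure_weight_tying=ensure_weight_tying,  # assumption: the #2803 argument lives on LoraConfig
    )
    model = get_peft_model(model, config)
    w_head = model.base_model.model.lm_head.token_adapter.trainable_tokens_delta.default
    w_embed = model.base_model.model.model.decoder.embed_tokens.token_adapter.trainable_tokens_delta.default
    assert w_head.data_ptr() == w_embed.data_ptr()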

BenjaminBossan avatar Nov 27 '25 10:11 BenjaminBossan

Okay @BenjaminBossan, I'll make that change in that PR. Thank you @Ambitious-idiot for pointing this out.

sambhavnoobcoder avatar Nov 27 '25 10:11 sambhavnoobcoder

Update: @BenjaminBossan, added the changes to the test in , checked it out, everything works and all tests pass.

sambhavnoobcoder avatar Nov 27 '25 17:11 sambhavnoobcoder