
ENH: Tie weights for target_modules in Lora (#2864)

Open: romitjain opened this pull request 1 month ago • 8 comments

Solves #2864 for target_modules

Enables the ensure_weight_tying flag in LoraConfig for target_modules.

For LoRA, if any of the tied layers is listed in target_modules and ensure_weight_tying == True, the adapter added to that layer is shared with all of its tied layers.

For example, if a model has tied weights and target_modules=['embed_tokens'], then LoRA adapters are added to both embed_tokens and lm_head, and the adapter in lm_head shares its weights with the adapter added to embed_tokens.
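A minimal usage sketch of how this is intended to be used (the ensure_weight_tying flag is the one proposed in this PR; the checkpoint name is only an example of a model with tied input/output embeddings):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# Any causal LM with tie_word_embeddings=True works; this checkpoint is just an example.
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

config = LoraConfig(
    target_modules=["embed_tokens"],
    ensure_weight_tying=True,  # proposed flag: also adapt lm_head and share the adapter weights
)
peft_model = get_peft_model(model, config)
peft_model.print_trainable_parameters()
```

With the flag enabled, the intent is that a single set of LoRA weights is trained and the lm_head adapter simply reuses it.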

romitjain avatar Oct 29 '25 06:10 romitjain

@BenjaminBossan I have added the relevant test cases and implemented the ensure_weight_tying flag for target_modules. The current implementation works only if embed_tokens is added to target_modules, not if lm_head is. I will implement that fix and update the PR, but in the meantime I would appreciate your views on the logic and implementation.

At a high level

  1. I have updated BaseTuner._check_tied_modules to also check for tied modules in target_modules.
  2. I have added a private method, BaseTuner._add_targets_to_tie, that needs to be implemented by the inheriting classes.
  3. I have added an extra loop in BaseTuner.inject_adapter to tie the adapters; it ensures that the order in which adapters are added to the target modules does not matter (a minimal illustration of the tying idea follows after this list).
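Conceptually, the tying in step 3 means the tied layer reuses the very same LoRA parameters as the source layer rather than receiving copies. A minimal illustration of that idea, independent of the actual PEFT internals (shapes and names are made up):

```python
import torch.nn as nn

# Stand-ins for the low-rank matrices of the source (embedding) adapter.
lora_A = nn.Linear(768, 8, bias=False)
lora_B = nn.Linear(8, 768, bias=False)

# "Tying" the lm_head adapter: point at the same modules, so updates are shared.
tied_lora_A, tied_lora_B = lora_A, lora_B

# One set of trainable parameters, two consumers.
assert tied_lora_A.weight is lora_A.weight
```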

Thank you

romitjain avatar Oct 29 '25 06:10 romitjain

@BenjaminBossan This is now ready for review. I have also updated the logic for tied layers in modules_to_save so that the lm_head and [embed_tokens, lm_head] cases are supported; earlier, they would not have worked. The high-level implementation remains the same, but in my opinion it is much better structured than in my earlier commits.
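In other words, both of the following configurations should now be handled on a model with tied embeddings (a sketch; the module names and the flag name are the ones discussed in this thread):

```python
from peft import LoraConfig

# Only the output side listed ...
cfg_a = LoraConfig(modules_to_save=["lm_head"], ensure_weight_tying=True)
# ... or both tied layers listed explicitly; either way the tie should be preserved.
cfg_b = LoraConfig(modules_to_save=["embed_tokens", "lm_head"], ensure_weight_tying=True)
```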

I have also added a few tests for the above case, and all of the tests pass.

The only thing remaining is how to check for target_modules in case it's a string. I will come back to it, but you can go ahead and review the core logic.
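For background on why the string case needs special handling: when target_modules is a str, PEFT treats it as a regex matched against full module names, so there is no explicit list of keys to intersect with the tied modules. Roughly (the module names below are made up):

```python
import re

target_modules = r".*embed_tokens"  # str form: treated as a regex over module names
keys = ["model.embed_tokens", "lm_head", "model.layers.0.self_attn.q_proj"]

matched = [k for k in keys if re.fullmatch(target_modules, k)]
print(matched)  # ['model.embed_tokens']
```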

romitjain avatar Oct 31 '25 12:10 romitjain

@BenjaminBossan I have addressed your comments. PTAL

romitjain avatar Nov 05 '25 12:11 romitjain

@BenjaminBossan please review the latest changes now. I believe I have addressed all your comments, but let me know if I missed something.

I have added test cases where target_modules is passed as a str, along with a (slightly) hacky solution for that case. It's an opinionated choice, made to keep the flow simple.

Regarding the transformers v5 update: since peft will have a locked transformers version, I believe we can merge this PR if it advances faster than that update. I can take up any follow-up changes whenever they're needed. However, you are much closer to this, so you can decide and let me know.

romitjain avatar Nov 07 '25 13:11 romitjain

@BenjaminBossan I have made the updates. As of now, we have three outstanding items to resolve:

  1. What to do in case target_modules is a str (https://github.com/huggingface/peft/pull/2879#discussion_r2503550670): I have replied to the comment with a small modification to the logic.
  2. Release of this PR: Can this PR be merged to main without being included in a release yet? Or do we have an ETA for the v5 arrival? Some of our internal PRs depend on this, hence asking.
  3. (Minor) Duplication of a warning message (https://github.com/huggingface/peft/pull/2879#discussion_r2510738148): I have replied to the comment.

romitjain avatar Nov 12 '25 07:11 romitjain

  1. https://github.com/huggingface/peft/pull/2879#discussion_r2518948684: Good suggestion. I have added a new function to find a layer by reference tensors (a rough sketch of this lookup follows after the pseudo-code below).
  2. I have resolved the merge conflict and addressed your remaining comments.
  3. Here's pseudo-code of the complete flow that is currently implemented:
```python
def inject_adapter():
    ...

    # finds tied modules and adds them to a set: target_modules_to_tie
    self._check_tied_modules(model, peft_config)

    tied_targets = []

    # loop 1: inject regular LoRA adapters, deferring tied targets
    for k in keys:
        result = k in target_modules
        to_tie = k in target_modules_to_tie

        if to_tie:
            tied_targets.append(k)
            continue
        if result:
            add_lora(k)

    # loop 2: now that all source adapters exist, tie the deferred targets
    for t in tied_targets:
        add_tied_lora(t)

    ...
```
  4. Here is the flow of the alternative that you suggested:
```python
def inject_adapter():
    ...
    # loop 1: inject regular LoRA adapters for every matching key
    for k in keys:
        result = k in target_modules
        if result:
            add_lora(k)

    # finds tied modules and adds them to a set: target_modules_to_tie
    self._check_tied_modules(model, peft_config)

    # this is needed since all the tied adapters reference the embedding
    # layer's LoRA as their source
    add_lora(embed_tokens)

    # loop 2: re-tie every tied target
    for t in target_modules_to_tie:
        remove_lora(t)  # removes the LoRA if it exists; this might be a bit complex
        add_tied_lora(t)

    ...
```
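As mentioned in item 1 above, the "find a layer by reference tensors" helper can be sketched roughly as follows (the name and signature here are hypothetical, not necessarily what the PR uses):

```python
import torch
import torch.nn as nn

def find_tied_module_names(model: nn.Module, ref: torch.Tensor) -> list[str]:
    """Return names of modules whose parameters share storage with `ref`,
    i.e. candidates that are weight-tied to the reference tensor."""
    matches = []
    for name, module in model.named_modules():
        for param in module.parameters(recurse=False):
            if param.data_ptr() == ref.data_ptr():
                matches.append(name)
                break
    return matches
```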

romitjain avatar Nov 13 '25 08:11 romitjain


@BenjaminBossan, I have fixed the test and removed 3 tests that are now redundant.

romitjain avatar Nov 14 '25 12:11 romitjain

@BenjaminBossan Resolved your comment

romitjain avatar Nov 18 '25 16:11 romitjain

@BenjaminBossan I made a small commit - in one of my earlier commits, I had made a change where the target modules were saved as model.embed_tokens instead of embed_tokens. This was causing some downstream issues. Apologies for the late commit, I know that you will have to trigger the runs again 😅

romitjain avatar Nov 19 '25 14:11 romitjain

Thanks for the update @romitjain. There is a new merge conflict, could you please check?

BenjaminBossan avatar Nov 20 '25 16:11 BenjaminBossan

@BenjaminBossan Done

romitjain avatar Nov 20 '25 16:11 romitjain

@BenjaminBossan Let me know if any steps remain on my side before the final push.

romitjain avatar Nov 21 '25 11:11 romitjain

@romitjain No, thank you, let's wait for @githubnemo's review.

BenjaminBossan avatar Nov 21 '25 14:11 BenjaminBossan

Hi @githubnemo, it would be very helpful if you could review the PR. One of our internal features depends on this :)

romitjain avatar Nov 28 '25 14:11 romitjain