ENH: Tie weights for target_modules in Lora (#2864)
Solves #2864 for `target_modules`.
Enables the `ensure_weight_tying` flag in `LoraConfig` for `target_modules`.
For LoRA, if any of the tied layers is listed in `target_modules` and `ensure_weight_tying=True`, the adapter added to that layer is shared with all of its tied layers.
For example, if a model has tied weights and `target_modules=['embed_tokens']`, then LoRA adapters are added to both `embed_tokens` and `lm_head`, and the adapter on `lm_head` shares its weights with the adapter added to `embed_tokens`.
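A minimal sketch of the intended usage, assuming the flag lands as described above (the checkpoint name is only an example of a model whose config sets `tie_word_embeddings=True`):

```python
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

# a model with tied input/output embeddings (tie_word_embeddings=True)
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")

config = LoraConfig(
    target_modules=["embed_tokens"],
    ensure_weight_tying=True,  # flag added in this PR
)
peft_model = get_peft_model(model, config)
# embed_tokens gets a LoRA adapter; because lm_head is tied to embed_tokens,
# it also gets an adapter that shares the same weights
```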
@BenjaminBossan
I have added the relevant test cases and implemented the `ensure_weight_tying` flag for `target_modules`. The current implementation works only if `embed_tokens` is added and not if `lm_head` is added. I will implement that fix and update the PR, but in the meantime I would appreciate your views on the logic and implementation.
At a high level:
- I have updated `BaseTuner._check_tied_modules` to check for tied modules in `target_modules`.
- I have added a private method `BaseTuner._add_targets_to_tie` that needs to be implemented by the inheriting classes.
- I have added a loop in `BaseTuner.inject_adapter` to tie the adapters. This extra loop ensures that the order in which adapters are added to the target modules does not matter (see the rough sketch after this list).
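A rough sketch of how these pieces could fit together; the method names are from the description above, but the signatures and bodies are illustrative assumptions, not the actual PR code:

```python
class BaseTuner:
    def _check_tied_modules(self, model, peft_config):
        # Illustrative: collect the targeted modules that are weight-tied
        # (e.g. lm_head tied to embed_tokens) into a set such as
        # target_modules_to_tie.
        ...

    def _add_targets_to_tie(self, tied_targets, adapter_name):
        # To be implemented by subclasses (e.g. the LoRA model), which know
        # how to create adapter layers that share weights with the source
        # adapter on the embedding layer.
        raise NotImplementedError

    def inject_adapter(self, model, adapter_name):
        # First pass: add regular adapters to the non-tied targets.
        # Second pass: add the shared (tied) adapters, so that the order in
        # which the targets were visited does not matter.
        ...
```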
Thank you
@BenjaminBossan This is now ready for review. I have also updated the logic for tied layers in `modules_to_save` so that the `lm_head` and `[embed_tokens, lm_head]` cases are supported. Earlier, they would not have worked. The high-level implementation remains the same, but in my opinion it is much better placed than in my earlier commits.
I have also added a few tests for the above case, and all of the tests pass.
The only thing remaining is how to check for `target_modules` when it is a string. I will come back to that, but you can go ahead and review the core logic.
@BenjaminBossan I have addressed your comments. PTAL
@BenjaminBossan please review the latest changes now. I believe I have addressed all your comments, but let me know if I missed something.
I have added test cases where we pass `target_modules` as a str and added a (slightly) hacky solution for that. It is opinionated, to keep the flow simple.
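For context, the two forms differ as sketched below; the regex behaviour for the str form is existing PEFT behaviour, while its interaction with the tying logic is the part being discussed:

```python
from peft import LoraConfig

# list form: module names are matched by suffix, so tied modules can be
# looked up in the list directly
cfg_list = LoraConfig(target_modules=["embed_tokens"])

# str form: treated as a regex against the full module name, so there is no
# explicit list to compare the tied modules against
cfg_str = LoraConfig(target_modules=r".*embed_tokens")
```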
Regarding the transformers v5 update: since the transformers version would be locked in peft, I believe that if this PR advances faster than that work, we can merge it. I can take up any follow-up changes whenever they are needed. However, you are much closer to this, so you can decide and let me know.
@BenjaminBossan I have made the updates. As of now, we have three outstanding items to resolve:
- What to do when `target_modules` is a str: https://github.com/huggingface/peft/pull/2879#discussion_r2503550670. I have replied to the comment with a small modification in the logic.
- Release of this PR: can this PR be merged to main but left out of a release? Or do we have an ETA for the v5 arrival? Some of our internal PRs depend on this, hence asking.
- (Minor) Duplication of a warning message: https://github.com/huggingface/peft/pull/2879#discussion_r2510738148. I have replied to the comment.
- https://github.com/huggingface/peft/pull/2879#discussion_r2518948684: Good suggestion. I have added a new function to find a layer by reference tensors (a rough sketch of the idea appears after the pseudo-code below).
- I have resolved the merge conflict and addressed your remaining comments.
- Here is pseudo-code of the complete flow that is currently implemented:
```python
def inject_adapter():
    ...
    # finds tied modules and adds to a set - target_modules_to_tie
    self._check_tied_modules(model, peft_config)
    tied_targets = []
    # loop 1
    for k in keys:
        result = k in target_modules
        to_tie = k in target_modules_to_tie
        if to_tie:
            tied_targets.append(k)
            continue
        if result:
            add_lora(k)
    # loop 2
    for t in tied_targets:
        add_tied_lora(t)
    ...
```
- Here is the flow of the alternative that you suggested:
```python
def inject_adapter():
    ...
    # loop 1
    for k in keys:
        result = k in target_modules
        if result:
            add_lora(k)
    # finds tied modules and adds to a set - target_modules_to_tie
    self._check_tied_modules(model, peft_config)
    # this is needed since all the tied adapters reference the embedding
    # layer's lora as source
    add_lora(embed_tokens)
    # loop 2
    for t in target_modules_to_tie:
        remove_lora(t)  # will remove lora if it exists, this might be a bit complex
        add_tied_lora(t)
    ...
```
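On the "find a layer by reference tensors" point above, here is a minimal sketch of the idea, assuming plain PyTorch; the helper name and signature are hypothetical, not the ones added in the PR:

```python
import torch.nn as nn

def find_modules_by_tensor(model: nn.Module, ref: nn.Parameter) -> list[str]:
    """Return the names of all modules whose weight is the given tensor.

    Tied layers literally share one Parameter object, so comparing by
    identity (`is`) finds every module tied to the reference tensor.
    """
    return [
        name
        for name, module in model.named_modules()
        if getattr(module, "weight", None) is ref
    ]

# e.g. find every module tied to the input embeddings:
# tied = find_modules_by_tensor(model, model.get_input_embeddings().weight)
```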
@BenjaminBossan, I have fixed the test. I have also removed three tests that are no longer required.
@BenjaminBossan Resolved your comment
@BenjaminBossan I made a small commit: in one of my earlier commits, I had made a change where the target modules were saved as `model.embed_tokens` instead of `embed_tokens`, which was causing some downstream issues.
Apologies for the late commit, I know that you will have to trigger the runs again 😅
Thanks for the update @romitjain. There is a new merge conflict; could you please check?
@BenjaminBossan Done
@BenjaminBossan Let me know if any steps remain on my side before the final push.
@romitjain No, thank you, let's wait for @githubnemo's review.
Hi @githubnemo, it would be very helpful if you could review the PR. One of our internal features depends on this :)