peft
Implement ensure_weight_tying for trainable_token_indices (#2864)
Summary
This PR implements consistent weight tying behavior for trainable_token_indices as specified in issue #2864. It extends the ensure_weight_tying parameter (introduced in PR #2803) to work with trainable_token_indices, providing users explicit control over weight tying between embeddings and LM head.
Fixes #2864 (trainable_token_indices portion)
Problem Statement
Background
PEFT models sometimes need to handle tied weights between embedding layers and LM head layers (when tie_word_embeddings=True). The ensure_weight_tying parameter was introduced in PR #2803 to give users explicit control over this behavior for modules_to_save. However, the same control was missing for trainable_token_indices.
The Issue
The issue identified that the weight tying behavior for trainable_token_indices was inconsistent across different scenarios. Specifically, four cases needed to be implemented:
- Untied model with ensure_weight_tying=True: Should warn users that weight tying cannot be applied
- Tied model with ensure_weight_tying=True and different indices: Should error, as it's impossible to tie adapters with different token indices
- Tied model with ensure_weight_tying=False and different indices: Should treat layers as separate (backwards compatibility behavior)
- Tied model with ensure_weight_tying=True and same indices: Should apply weight tying correctly
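As a hedged illustration, the dict-format configurations that separate the tied-model cases might look like the following. The layer names `embed_tokens` and `lm_head` and the `target_modules` values are assumptions (they match Llama-style models), and the snippet needs a PEFT version that includes PR #2803 and this PR. Whether a configuration ties, warns, or errors can only be decided once it is applied to a model, because it depends on the model's `tie_word_embeddings` setting.

```python
from peft import LoraConfig

# Assumed Llama-style layer names and target modules; adjust for your model.
# Same indices on the embedding and LM head: tied on a tied model (Case 4).
tied_config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices={"embed_tokens": [0, 1, 2], "lm_head": [0, 1, 2]},
    ensure_weight_tying=True,
)

# Different indices with ensure_weight_tying=True: expected to raise a
# ValueError once applied to a tied model (Case 2), not at construction time.
conflicting_config = LoraConfig(
    target_modules=["q_proj", "v_proj"],
    trainable_token_indices={"embed_tokens": [0, 1, 2], "lm_head": [5, 6]},
    ensure_weight_tying=True,
)
```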
Solution Approach
Implementation Strategy:
- Check weight tying configuration early (before creating wrappers)
- Detect if user specified both embedding and lm_head layers in dict format
- Check if their token indices match or differ
- Apply appropriate logic based on the configuration matrix from the issue
- Skip creating wrappers for layers that will be tied later
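The detection and skip steps of this strategy can be sketched in plain Python. This is a minimal, hypothetical helper, not the code in src/peft/utils/other.py: the name `plan_token_wrappers`, the default layer names `embed_tokens`/`lm_head`, and the order-insensitive index comparison are all assumptions. It also assumes that the list format targets only the input embedding.

```python
from typing import Dict, List, Union

TokenIndices = Union[List[int], Dict[str, List[int]]]


def plan_token_wrappers(
    trainable_token_indices: TokenIndices,
    embed_name: str = "embed_tokens",
    lm_head_name: str = "lm_head",
    tie: bool = False,
) -> Dict[str, List[int]]:
    """Return the layers that should get their own trainable-tokens wrapper.

    Hypothetical helper: names and return shape are illustrative only.
    """
    # Assumed list format: the indices apply to the input embedding only.
    if isinstance(trainable_token_indices, list):
        return {embed_name: list(trainable_token_indices)}

    # Copy so the user's config dict is never mutated.
    layers = dict(trainable_token_indices)
    both_specified = embed_name in layers and lm_head_name in layers
    if both_specified:
        # Assumption: indices are compared as sets, ignoring order.
        same_indices = set(layers[embed_name]) == set(layers[lm_head_name])
        if tie and same_indices:
            # Skip the LM head: it will be tied to the embedding wrapper later,
            # so wrapping it here would create a second, independent adapter.
            del layers[lm_head_name]
    return layers
```

Skipping the LM head up front, rather than wrapping it and replacing the wrapper afterwards, is what prevents the double-wrapping mentioned under Skip Logic.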
Changes Made
1. Updated Configuration Documentation
File: src/peft/tuners/lora/config.py
Updated the ensure_weight_tying parameter docstring to clarify that it now applies to both modules_to_save and trainable_token_indices, making the documentation consistent with the implementation.
2. Implemented Weight Tying Logic
File: src/peft/utils/other.py
Added comprehensive logic within the existing trainable_token_indices handling block:
Key Components:
- Early Detection: Check weight tying configuration before creating any wrappers
- Layer Detection: Identify if both embedding and lm_head layers are specified
- Index Comparison: Determine if token indices match between the layers
- Skip Logic: Prevent double-wrapping by skipping layers that will be tied
- Warning System: Inform users when their configuration cannot be applied
- Error Handling: Raise clear errors for contradictory configurations
- Backwards Compatibility: Preserve existing behavior when ensure_weight_tying=False
Four Cases Implemented:
- Case 1 - Warning for Untied Models:
  - When: weights_tied=False + ensure_weight_tying=True
  - Action: Issue a warning that weight tying cannot be applied
  - Rationale: The model doesn't have tied weights, so the user's request cannot be fulfilled
- Case 2 - Error for Contradictory Configuration:
  - When: weights_tied=True + ensure_weight_tying=True + different indices
  - Action: Raise a ValueError with a clear explanation
  - Rationale: Adapters that operate on different token indices cannot be tied
- Case 3 - Backwards Compatibility:
  - When: weights_tied=True + ensure_weight_tying=False + different indices
  - Action: Treat the layers as separate (no tying)
  - Rationale: The user explicitly opted out; respect that choice even if the model supports tying
- Case 4 - Apply Tying:
  - When: weights_tied=True with matching indices (with ensure_weight_tying=True, or with ensure_weight_tying=False through the existing automatic tying)
  - Action: Create tied adapters that share parameters
  - Rationale: Normal weight tying behavior
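The four cases above reduce to a small decision function. The sketch below is hypothetical: the name `resolve_token_tying` and the message wording are invented for illustration and are not PEFT's actual code. It takes the model's tying state, the flag, and whether the embedding and LM head indices match; a caller would pass `True` for the list format, since only one set of indices is given there.

```python
import warnings


def resolve_token_tying(weights_tied: bool, ensure_weight_tying: bool, same_indices: bool) -> bool:
    """Return True if the embedding and LM-head adapters should be tied.

    Hypothetical helper mirroring the four cases; not PEFT's actual code.
    """
    if ensure_weight_tying and not weights_tied:
        # Case 1: tying was requested, but the model has no tied weights.
        warnings.warn(
            "ensure_weight_tying=True, but the model does not tie its embedding "
            "and LM head weights; weight tying cannot be applied."
        )
        return False
    if weights_tied and not same_indices:
        if ensure_weight_tying:
            # Case 2: adapters training different token indices cannot be tied.
            raise ValueError(
                "Cannot ensure weight tying: the embedding and LM head use "
                "different trainable token indices."
            )
        # Case 3: backwards compatibility, treat the layers as separate.
        return False
    # Case 4 (tied model, matching indices) ties; an untied model without the flag does not.
    return weights_tied
```

Checking the untied-model case first means the warning fires regardless of the indices, which matches the warning tests being run for both the list and dict formats.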
3. Comprehensive Test Suite
File: tests/test_trainable_tokens.py
Added 7 new test methods covering all scenarios:
Test Coverage:
- test_ensure_weight_tying_warns_when_model_not_tied_list_format: Verifies warning for list format
- test_ensure_weight_tying_warns_when_model_not_tied_dict_format: Verifies warning for dict format
- test_weight_tying_bc_different_indices_treated_separately: Verifies backwards compatibility
- test_ensure_weight_tying_errors_with_different_indices: Verifies error for contradictory config
- test_ensure_weight_tying_applied_with_same_indices: Verifies tying with same indices
- test_weight_tying_bc_same_indices_applied: Verifies BC for same indices
- test_ensure_weight_tying_with_single_layer: Verifies list format tying
Testing Results
New Tests
All 7 new tests pass successfully:
- ✅ test_ensure_weight_tying_warns_when_model_not_tied_list_format
- ✅ test_ensure_weight_tying_warns_when_model_not_tied_dict_format
- ✅ test_weight_tying_bc_different_indices_treated_separately
- ✅ test_ensure_weight_tying_errors_with_different_indices
- ✅ test_ensure_weight_tying_applied_with_same_indices
- ✅ test_weight_tying_bc_same_indices_applied
- ✅ test_ensure_weight_tying_with_single_layer
Backwards Compatibility
This implementation maintains full backwards compatibility:
- ✅ Default Behavior Unchanged: ensure_weight_tying defaults to False, preserving existing behavior
- ✅ No Breaking Changes: Existing code continues to work without modification
- ✅ Opt-in Enhancement: Users must explicitly set ensure_weight_tying=True to use the new features
- ✅ BC Mode Preserved: When ensure_weight_tying=False, existing automatic tying still works for compatible configurations
Checklist
- [x] Implementation follows the specification in issue #2864
- [x] All 7 new tests pass
- [x] Backwards compatibility maintained
- [x] Documentation updated (docstring)
- [x] Code is scoped only to trainable_token_indices
- [x] Error messages are clear and actionable
- [x] Warning messages inform users appropriately
cc: @BenjaminBossan
About the test coverage: the table looks correct. I've filled all 6 gaps in the test coverage:
- Added 2 new standalone test functions (test_untied_model_list_format_no_ensure and test_tied_model_list_format_no_ensure)
- Expanded the parametrized test_ensure_weight_tying_warns_when_model_not_tied from 2 to 4 scenarios (adding the dict format cases)
- Added the parametrized test_untied_model_dict_no_ensure covering 2 scenarios (same and different indices)
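The coverage table under discussion can be viewed as a grid over four axes. The enumeration below is a hypothetical sketch (the function name and tuple layout are invented); it assumes the list format has no index-matching axis because only one set of indices is given, which yields 12 scenarios in total.

```python
from itertools import product


def coverage_scenarios():
    """Enumerate (weights_tied, ensure_weight_tying, format, same_indices) tuples.

    Hypothetical sketch of the test-coverage grid; same_indices is None for
    the list format, which has no second set of indices to compare.
    """
    scenarios = []
    for tied, ensure, fmt in product([True, False], [True, False], ["list", "dict"]):
        if fmt == "list":
            scenarios.append((tied, ensure, fmt, None))
        else:
            # Dict format: the embedding and LM head indices may or may not match.
            for same in (True, False):
                scenarios.append((tied, ensure, fmt, same))
    return scenarios
```

Each tuple maps naturally onto one `pytest.mark.parametrize` case, which makes gaps in the table easy to spot.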
@BenjaminBossan Thank you for the detailed review. I have made all the changes and would appreciate it if you could take another look. I'll make any further changes necessary.
@sambhavnoobcoder Are you still working on this?
Oh hi @BenjaminBossan, I had actually resolved all the comments already and just forgot to tag you for a review. I also just resolved a small merge conflict, so this is ready for your review / merge now.
@BenjaminBossan Resolved this too; please re-review, and we can merge if all looks good.
Thanks @sambhavnoobcoder, the PR looks good from my side.
@githubnemo could you please also do a review :pray:?