Implement ensure_weight_tying for trainable_token_indices (#2864)

Open sambhavnoobcoder opened this issue 1 month ago • 2 comments

Summary

This PR implements consistent weight tying behavior for trainable_token_indices as specified in issue #2864. It extends the ensure_weight_tying parameter (introduced in PR #2803) to work with trainable_token_indices, providing users explicit control over weight tying between embeddings and LM head.

Fixes #2864 (trainable_token_indices portion)


Problem Statement

Background

PEFT models sometimes need to handle tied weights between embedding layers and LM head layers (when tie_word_embeddings=True). The ensure_weight_tying parameter was introduced in PR #2803 to give users explicit control over this behavior for modules_to_save. However, the same control was missing for trainable_token_indices.

The Issue

Issue #2864 identified that the weight tying behavior for trainable_token_indices was inconsistent across scenarios. Specifically, four cases needed to be implemented:

  1. Untied model with ensure_weight_tying=True: Should warn users that weight tying cannot be applied
  2. Tied model with ensure_weight_tying=True and different indices: Should error, as it's impossible to tie adapters with different token indices
  3. Tied model with ensure_weight_tying=False and different indices: Should treat layers as separate (backwards compatibility behavior)
  4. Tied model with ensure_weight_tying=True and same indices: Should apply weight tying correctly

Solution Approach

Implementation Strategy:

  1. Check weight tying configuration early (before creating wrappers)
  2. Detect if user specified both embedding and lm_head layers in dict format
  3. Check if their token indices match or differ
  4. Apply appropriate logic based on the configuration matrix from the issue
  5. Skip creating wrappers for layers that will be tied later
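
Steps 2 and 3 of the strategy can be sketched in plain Python. This is an illustration only, not the code from src/peft/utils/other.py: the helper names (normalize_targets, compare_tied_layers), the default layer names "embed_tokens" and "lm_head", and the order-insensitive index comparison are all assumptions made for the example.

```python
from typing import Dict, List, Tuple, Union

TokenIndices = Union[List[int], Dict[str, List[int]]]


def normalize_targets(
    trainable_token_indices: TokenIndices,
    default_layer: str = "embed_tokens",  # assumed default target for list format
) -> Dict[str, List[int]]:
    """Return a {layer_name: indices} mapping for both list and dict formats."""
    if isinstance(trainable_token_indices, dict):
        return {name: list(idx) for name, idx in trainable_token_indices.items()}
    return {default_layer: list(trainable_token_indices)}


def compare_tied_layers(
    targets: Dict[str, List[int]],
    embedding_name: str = "embed_tokens",  # hypothetical layer name
    lm_head_name: str = "lm_head",  # hypothetical layer name
) -> Tuple[bool, bool]:
    """Report whether both layers are targeted and whether their indices agree."""
    both_specified = embedding_name in targets and lm_head_name in targets
    # Assumption: indices are compared as sets, so ordering does not matter.
    indices_match = both_specified and (
        set(targets[embedding_name]) == set(targets[lm_head_name])
    )
    return both_specified, indices_match
```

For example, compare_tied_layers({"embed_tokens": [1, 2], "lm_head": [3]}) reports that both layers are specified but their indices differ, which is the situation cases 2 and 3 distinguish.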

Changes Made

1. Updated Configuration Documentation

File: src/peft/tuners/lora/config.py

Updated the ensure_weight_tying parameter docstring to clarify that it now applies to both modules_to_save and trainable_token_indices, making the documentation consistent with the implementation.

2. Implemented Weight Tying Logic

File: src/peft/utils/other.py

Added comprehensive logic within the existing trainable_token_indices handling block:

Key Components:

  • Early Detection: Check weight tying configuration before creating any wrappers
  • Layer Detection: Identify if both embedding and lm_head layers are specified
  • Index Comparison: Determine if token indices match between the layers
  • Skip Logic: Prevent double-wrapping by skipping layers that will be tied
  • Warning System: Inform users when their configuration cannot be applied
  • Error Handling: Raise clear errors for contradictory configurations
  • Backwards Compatibility: Preserve existing behavior when ensure_weight_tying=False
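
The skip logic from the list above might look roughly like the following sketch. It assumes that, when tying applies, the LM head is the layer left unwrapped in the first pass and tied to the embedding's adapter afterwards; the function name plan_wrappers and the layer name "lm_head" are hypothetical and not taken from the PR.

```python
from typing import Dict, List, Tuple


def plan_wrappers(
    targets: Dict[str, List[int]],
    tie: bool,
    lm_head_name: str = "lm_head",  # hypothetical layer name
) -> Tuple[Dict[str, List[int]], List[str]]:
    """Split target layers into those wrapped now and those tied afterwards."""
    wrap_now: Dict[str, List[int]] = {}
    tie_later: List[str] = []
    for name, indices in targets.items():
        if tie and name == lm_head_name:
            # Skipped here so the layer is not wrapped twice (assumed behavior).
            tie_later.append(name)
        else:
            wrap_now[name] = indices
    return wrap_now, tie_later
```

With tie=False every layer lands in wrap_now, which corresponds to treating the layers as separate.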

Four Cases Implemented:

  1. Case 1 - Warning for Untied Models:

    • When: weights_tied=False + ensure_weight_tying=True
    • Action: Issue warning that weight tying cannot be applied
    • Rationale: Model doesn't have tied weights, so user's request cannot be fulfilled
  2. Case 2 - Error for Contradictory Configuration:

    • When: weights_tied=True + ensure_weight_tying=True + different indices
    • Action: Raise ValueError with clear explanation
    • Rationale: Cannot tie adapters that operate on different token indices
  3. Case 3 - Backwards Compatibility:

    • When: weights_tied=True + ensure_weight_tying=False + different indices
    • Action: Treat layers as separate (no tying)
    • Rationale: User explicitly opted out, respect their choice even if model supports tying
  4. Case 4 - Apply Tying:

    • When: Remaining tied-model combinations, e.g. weights_tied=True + same indices (with ensure_weight_tying=True or False)
    • Action: Create tied adapters that share parameters
    • Rationale: Normal weight tying behavior
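
The four cases can be condensed into a small decision function. This is a minimal sketch under stated assumptions rather than the PR's implementation: resolve_weight_tying and its return values are invented for illustration, the message texts are placeholders, and it assumes that an untied model always ends up with separate layers after the warning.

```python
import warnings


def resolve_weight_tying(
    weights_tied: bool, ensure_weight_tying: bool, indices_match: bool
) -> str:
    """Return "tie" or "separate"; warn or raise for unsatisfiable requests."""
    if not weights_tied:
        if ensure_weight_tying:
            # Case 1: the model has no tied weights, so the request cannot apply.
            warnings.warn(
                "ensure_weight_tying=True was set, but the model's weights "
                "are not tied; weight tying cannot be applied."
            )
        return "separate"
    if not indices_match:
        if ensure_weight_tying:
            # Case 2: adapters with different token indices cannot be tied.
            raise ValueError(
                "ensure_weight_tying=True requires the embedding and LM head "
                "to use the same token indices."
            )
        # Case 3: backwards-compatible behavior, layers stay separate.
        return "separate"
    # Case 4: tied model with matching indices.
    return "tie"
```

A caller would pass indices_match=True when only one layer is specified, since there is nothing to conflict with in that situation (again an assumption of this sketch).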

3. Comprehensive Test Suite

File: tests/test_trainable_tokens.py

Added 7 new test methods covering all scenarios:

Test Coverage:

  • test_ensure_weight_tying_warns_when_model_not_tied_list_format: Verifies warning for list format
  • test_ensure_weight_tying_warns_when_model_not_tied_dict_format: Verifies warning for dict format
  • test_weight_tying_bc_different_indices_treated_separately: Verifies backwards compatibility
  • test_ensure_weight_tying_errors_with_different_indices: Verifies error for contradictory config
  • test_ensure_weight_tying_applied_with_same_indices: Verifies tying with same indices
  • test_weight_tying_bc_same_indices_applied: Verifies BC for same indices
  • test_ensure_weight_tying_with_single_layer: Verifies list format tying

Testing Results

New Tests

All 7 new tests pass successfully:

  • ✅ test_ensure_weight_tying_warns_when_model_not_tied_list_format
  • ✅ test_ensure_weight_tying_warns_when_model_not_tied_dict_format
  • ✅ test_weight_tying_bc_different_indices_treated_separately
  • ✅ test_ensure_weight_tying_errors_with_different_indices
  • ✅ test_ensure_weight_tying_applied_with_same_indices
  • ✅ test_weight_tying_bc_same_indices_applied
  • ✅ test_ensure_weight_tying_with_single_layer

Backwards Compatibility

This implementation maintains full backwards compatibility:

  • ✅ Default Behavior Unchanged: ensure_weight_tying defaults to False, preserving existing behavior
  • ✅ No Breaking Changes: Existing code continues to work without modification
  • ✅ Opt-in Enhancement: Users must explicitly set ensure_weight_tying=True to use new features
  • ✅ BC Mode Preserved: When ensure_weight_tying=False, existing automatic tying still works for compatible configurations


Checklist

  • [x] Implementation follows the specification in issue #2864
  • [x] All 7 new tests pass
  • [x] Backwards compatibility maintained
  • [x] Documentation updated (docstring)
  • [x] Code is scoped only to trainable_token_indices
  • [x] Error messages are clear and actionable
  • [x] Warning messages inform users appropriately

cc: @BenjaminBossan


sambhavnoobcoder, Oct 26 '25 14:10

About the test coverage: the table looks correct. I've filled all 6 gaps in the test coverage:

  • Added 2 new standalone test functions (test_untied_model_list_format_no_ensure and test_tied_model_list_format_no_ensure)
  • Expanded the parametrized test_ensure_weight_tying_warns_when_model_not_tied from 2 to 4 scenarios (adding the dict format cases)
  • Added parametrized test_untied_model_dict_no_ensure covering 2 scenarios (same and different indices)

sambhavnoobcoder, Oct 29 '25 19:10

@BenjaminBossan Thank you for the detailed review. I have made all the changes and would appreciate it if you could have another look. I'll make any further changes necessary.

sambhavnoobcoder, Oct 29 '25 19:10

@sambhavnoobcoder Are you still working on this?

BenjaminBossan, Nov 18 '25 13:11

Oh hi @BenjaminBossan, actually I had already resolved all the comments, I just forgot to tag you for a review. I just resolved a small merge conflict, so this is ready for your review / merging now.

sambhavnoobcoder, Nov 18 '25 15:11

@BenjaminBossan resolved this too. Please re-review, and we can merge if all looks good.

sambhavnoobcoder, Nov 19 '25 18:11

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Thanks @sambhavnoobcoder, the PR looks good from my side.

@githubnemo could you please also do a review :pray:?

BenjaminBossan, Nov 21 '25 14:11