
Save adapter config and remapped adapter weights for loading into PEFT

Open • ebsmothers opened this issue 1 year ago • 1 comment

This is a PR for integrating with PEFT so that checkpoints from torchtune can be further fine-tuned. We save a file adapter_config.json, along with adapter_model.bin, to match the format expected by PEFT. We also remap the LoRA weights to match the HF format, because torchtune and HF implement RoPE differently (HF checkpoints store the Q and K projection weights in a permuted layout).
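For readers unfamiliar with the PEFT layout, here is a minimal sketch of writing an adapter_config.json. It assumes the file uses PEFT's `LoraConfig` field names (`peft_type`, `r`, `lora_alpha`, `target_modules`, `base_model_name_or_path`); the helper names and the exact key set are hypothetical and not taken from torchtune's implementation.

```python
import json
import os
import tempfile


def build_adapter_config(rank, alpha, target_modules, base_model_path):
    # Hypothetical helper: keys mirror PEFT's LoraConfig field names,
    # but the exact set of keys torchtune writes is an assumption.
    return {
        "peft_type": "LORA",
        "r": rank,
        "lora_alpha": alpha,
        "target_modules": sorted(target_modules),
        "base_model_name_or_path": base_model_path,
    }


def save_adapter_config(config, output_dir):
    # PEFT expects adapter_config.json alongside the adapter weights file.
    path = os.path.join(output_dir, "adapter_config.json")
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    return path
```

Usage: `save_adapter_config(build_adapter_config(8, 16, ["q_proj", "v_proj"], "<base model>"), output_dir)`, where the base model path is a placeholder.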

The save logic depends on the checkpointer and the model. In summary:

  • For the Meta and tune checkpointers we make no changes and continue to save adapter weights in the tune format. This follows the principle of "same input format, same output format", since the PEFT format corresponds only to the HF format.
  • For the HF checkpointer we still output tune-format adapter weights (to allow resuming from intermediate checkpoints), and we additionally output HF-mapped adapter weights, except for Phi-3 models. See this comment for the rationale.
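The HF-mapped adapter weights also need PEFT-style parameter names. Below is a hypothetical sketch of that key remapping for Llama-style models, assuming torchtune keys of the form `layers.{i}.attn.q_proj.lora_a.weight` and PEFT keys of the form `base_model.model.model.layers.{i}.self_attn.q_proj.lora_A.weight`. The module-name table is an illustrative assumption, the sketch skips the final output projection, and it renames keys only (it does not apply the RoPE permutation).

```python
import re

# Assumed torchtune -> HF module-name mapping for Llama-style models
# (illustrative only, not copied from torchtune).
_TUNE_TO_HF = {
    "attn.q_proj": "self_attn.q_proj",
    "attn.k_proj": "self_attn.k_proj",
    "attn.v_proj": "self_attn.v_proj",
    "attn.output_proj": "self_attn.o_proj",
    "mlp.w1": "mlp.gate_proj",
    "mlp.w2": "mlp.down_proj",
    "mlp.w3": "mlp.up_proj",
}

# Matches e.g. "layers.0.attn.q_proj.lora_a.weight".
_LAYER_KEY = re.compile(r"^layers\.(\d+)\.(\w+\.\w+)\.lora_([ab])\.weight$")


def tune_to_peft_key(key):
    """Map a (hypothetical) torchtune LoRA key to a PEFT-style key."""
    m = _LAYER_KEY.match(key)
    if m is None:
        raise KeyError(f"unrecognized adapter key: {key}")
    layer, module, ab = m.groups()
    hf_module = _TUNE_TO_HF[module]
    # PEFT names the LoRA matrices lora_A / lora_B (upper case).
    return f"base_model.model.model.layers.{layer}.{hf_module}.lora_{ab.upper()}.weight"
```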

Testing:

Unit tests

Added a unit test in test_checkpointer.py to verify that the adapter config and PEFT-compatible weights are saved as expected

pytest tests/torchtune/utils/test_checkpointer.py
...
======= 6 passed in 1.19s ==========

Recipe tests

pytest -m integration_test tests/recipes
...
==== 18 passed, 1 deselected, 3 warnings in 167.99s (0:02:47) =======

Manual E2E test

First, create the file test_peft_integration.py from this gist.

(1) ✅ Permuting the LoRA weights works as expected, i.e. _permute_lora_matrix(B) * A = _permute(B*A), which I think is what we want.
(2) ✅ Uploaded adapter weights can be loaded into a transformers model via from_pretrained.
(3) ✅ Model forward passes match within a reasonable tolerance between PEFT-loaded and torchtune-loaded checkpoints.
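Check (1) can be reproduced in isolation. This NumPy sketch uses the per-head reshape-and-transpose from HF's Llama weight-conversion script as a stand-in for the Q/K permutation; `permute_rows` is a hypothetical name, not torchtune's `_permute_lora_matrix`. Because the operation is a pure row permutation of the output dimension, permuting only LoRA B gives the same result as permuting the full update BA, so A needs no change.

```python
import numpy as np


def permute_rows(w, n_heads):
    # Per-head reshape/transpose as in HF's Llama conversion script:
    # regroups each head's rows from interleaved rotary pairs into two halves.
    # It only reorders rows (the output dimension).
    dim1, dim2 = w.shape
    return (
        w.reshape(n_heads, dim1 // n_heads // 2, 2, dim2)
        .swapaxes(1, 2)
        .reshape(dim1, dim2)
    )


rng = np.random.default_rng(0)
n_heads, head_dim, rank, in_dim = 4, 8, 2, 16
out_dim = n_heads * head_dim
lora_a = rng.standard_normal((rank, in_dim))   # A: (rank, in_dim)
lora_b = rng.standard_normal((out_dim, rank))  # B: (out_dim, rank)

# Permuting B then multiplying by A equals permuting the product B @ A.
lhs = permute_rows(lora_b, n_heads) @ lora_a
rhs = permute_rows(lora_b @ lora_a, n_heads)
print(np.allclose(lhs, rhs))  # True
```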

For (3):

Test case 1: default config (Q and V only)

tune run lora_finetune_single_device --config llama2/7B_lora_single_device gradient_accumulation_steps=1 \
max_steps_per_epoch=500 dtype=fp32 checkpointer.output_dir=/data/users/ebs/test_peft_integration

This saves a fine-tuned LoRA checkpoint, with the adapter config and adapter weights in PEFT format. Then, to compare the forward pass when loading the fine-tuned checkpoint into PEFT versus into torchtune:

python3 test_peft_integration.py --checkpoint-dir=/data/users/ebs/test_peft_integration
...
Maximum difference: 9.298324584960938e-05

Test case 2: all layers, custom LoRA rank and alpha

tune run lora_finetune_single_device --config llama2/7B_lora_single_device \
model.lora_attn_modules=['q_proj','k_proj','v_proj','output_proj'] model.apply_lora_to_mlp=True \
model.apply_lora_to_output=True gradient_accumulation_steps=1 max_steps_per_epoch=100 model.lora_rank=16 \
model.lora_alpha=64 dtype=fp32 checkpointer.output_dir=/data/users/ebs/test_peft_integration_full

Then

python3 test_peft_integration.py --checkpoint-dir=/data/users/ebs/test_peft_integration_full
...
Maximum difference: 0.000152587890625
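The gist that computes the maximum differences above is not reproduced here. As a rough, hypothetical illustration of why the adapter config must record rank and alpha: LoRA scales the low-rank update by alpha / rank (64 / 16 = 4 in test case 2), so a loader that used a different alpha would produce different outputs. This NumPy sketch compares an unmerged and a merged LoRA forward pass in fp32 and reports the maximum absolute difference; shapes and values are arbitrary.

```python
import numpy as np


def lora_forward(x, w, lora_a, lora_b, rank, alpha):
    # Unmerged LoRA: base projection plus the low-rank update scaled by alpha / rank.
    scaling = alpha / rank
    return x @ w.T + scaling * (x @ lora_a.T @ lora_b.T)


def merged_forward(x, w, lora_a, lora_b, rank, alpha):
    # Merged LoRA: fold (alpha / rank) * B @ A into the base weight first.
    w_merged = w + (alpha / rank) * (lora_b @ lora_a)
    return x @ w_merged.T


rng = np.random.default_rng(0)
rank, alpha, in_dim, out_dim = 16, 64, 32, 24
x = rng.standard_normal((4, in_dim)).astype(np.float32)
w = rng.standard_normal((out_dim, in_dim)).astype(np.float32)
lora_a = rng.standard_normal((rank, in_dim)).astype(np.float32) * 0.01
lora_b = rng.standard_normal((out_dim, rank)).astype(np.float32) * 0.01

y1 = lora_forward(x, w, lora_a, lora_b, rank, alpha)
y2 = merged_forward(x, w, lora_a, lora_b, rank, alpha)
max_diff = np.abs(y1 - y2).max()
print(f"Maximum difference: {max_diff}")

# If a loader ignored lora_alpha (e.g. assumed alpha == rank), outputs would diverge.
y_bad = lora_forward(x, w, lora_a, lora_b, rank, alpha=rank)
print(f"Difference with wrong alpha: {np.abs(y1 - y_bad).max()}")
```

The two correct paths agree up to fp32 rounding, while the mismatched alpha shifts the outputs by a much larger amount.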

ebsmothers • May 03 '24 22:05

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/933

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: No Failures

As of commit 9eb9b68a4145fa552c152a6a77603997c640c3a0 with merge base 29ae975fc6d2f8e85ce33634116e0bda0472253c: :green_heart: Looks good so far! There are no failures yet. :green_heart:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] • May 03 '24 22:05