
KD recipe update

Open · rajuptvs opened this pull request 9 months ago · 8 comments

Context

What is the purpose of this PR? Is it to

  • [x] add a new feature
  • [ ] fix a bug
  • [ ] update tests and/or documentation
  • [ ] other (please add here)

Please link to any issues this PR addresses. #1959

Changelog

What are the changes made in this PR? This PR enables activation offloading and optimizer-in-backward in the knowledge distillation recipes, mirroring #1847 and #1833, to reduce memory usage.

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any of these, just ask and we will happily help. We also have a contributing page for guidance on contributing.

  • [x] run pre-commit hooks and linters (make sure you've first installed via `pre-commit install`)
  • [ ] add unit tests for any new functionality
  • [x] update docstrings for any new or updated methods or classes
  • [ ] run unit tests via `pytest tests`
  • [ ] run recipe tests via `pytest tests -m integration_test`
  • [x] manually run any new or modified recipes with sufficient proof of correctness
  • [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

Below are the wandb runs for each setting:

Activation Offloading (Single):

single_device_activation_offloading

Activation Offloading (Distributed):

distributed_activation_offloading

Optimizer in Backward (Single):

opt_in_bwd_single_device

Optimizer in Backward (Distributed):

opt_in_bwd_distributed

Config.yaml for activation_offloading

output_dir: /tmp/torchtune/qwen2_1_5_to_0_5B/KD_lora_distributed
model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  apply_lora_to_mlp: true
  lora_rank: 32
  lora_alpha: 64
teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files:
  - model.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct
  checkpoint_files:
  - model.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2
resume_from_checkpoint: false
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: false
seed: null
shuffle: true
batch_size: 8
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 0.0003
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5
epochs: 1
max_steps_per_epoch: null
compile: false
gradient_accumulation_steps: 8
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: false
  output_dir: ${output_dir}/profiling_outputs
  cpu: true
  cuda: true
  profile_memory: false
  with_stack: false
  record_shapes: true
  with_flops: false
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
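
To illustrate what the enable_activation_offloading flag buys, here is a minimal, generic PyTorch sketch of the underlying idea (assuming a CUDA device is available): activations saved for backward are parked in pinned CPU memory during the forward pass and copied back to the GPU only when backward needs them. This is just an illustration of the concept; the recipe itself relies on torchtune's own activation-offloading utilities behind that flag.

# Minimal sketch of the activation-offloading idea (not the recipe's implementation).
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
x = torch.randn(8, 1024, device="cuda")

# Activations saved for backward are moved to pinned CPU memory during forward.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(x).sum()

# They are copied back to the GPU lazily as backward consumes them.
loss.backward()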

Config.yaml for opt_in_bwd

output_dir: /tmp/torchtune/qwen2_1_5_to_0_5B/KD_lora_distributed
model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  apply_lora_to_mlp: true
  lora_rank: 32
  lora_alpha: 64
teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files:
  - model.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct
  checkpoint_files:
  - model.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2
resume_from_checkpoint: false
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: false
seed: null
shuffle: true
batch_size: 8
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 0.0003
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5
epochs: 1
max_steps_per_epoch: null
compile: false
gradient_accumulation_steps: 1
device: cuda
dtype: bf16
enable_activation_checkpointing: false
enable_activation_offloading: false
optimizer_in_bwd: true
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: false
  output_dir: ${output_dir}/profiling_outputs
  cpu: true
  cuda: true
  profile_memory: false
  with_stack: false
  record_shapes: true
  with_flops: false
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
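
For context on the optimizer_in_bwd flag, below is a minimal, generic PyTorch sketch of the optimizer-in-backward pattern (the general technique, not the exact torchtune wiring): one small optimizer per parameter is stepped from a post-accumulate-grad hook, so each gradient can be released right after it is produced instead of being held for a global optimizer.step(). Because updates happen inside backward, the pattern does not compose with gradient accumulation, which is presumably why this config sets gradient_accumulation_steps to 1.

# Minimal sketch of optimizer-in-backward using per-parameter optimizers and
# post-accumulate-grad hooks (generic PyTorch, not the recipe's exact code).
import torch
import torch.nn as nn

model = nn.Linear(1024, 1024).cuda()

# One optimizer per parameter, mirroring the AdamW settings in the config above.
optimizers = {
    p: torch.optim.AdamW([p], lr=3e-4, weight_decay=0.01) for p in model.parameters()
}

def step_and_free(param: torch.Tensor) -> None:
    # Runs right after this parameter's gradient is fully accumulated.
    optimizers[param].step()
    optimizers[param].zero_grad(set_to_none=True)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(step_and_free)

x = torch.randn(8, 1024, device="cuda")
model(x).sum().backward()  # parameters are updated as backward runs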

Command used to run the distributed KD recipe

tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config custom.yaml

Command used to run the single-device KD recipe

tune run knowledge_distillation_single_device --config custom.yaml

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example.

  • [ ] I did not change any public API
  • [ ] I have added an example to docs or docstrings

rajuptvs · Feb 14 '25

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2395

Note: Links to docs will display an error until the docs builds have been completed.

:x: 2 New Failures, 1 Cancelled Job, 1 Unrelated Failure

As of commit 3220b41072776825c5dbf3837a14432ebe7a00e7 with merge base 8fd697188f25832343cc013b89b354f0f8368b78:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] · Feb 14 '25

Hi @rajuptvs, very happy to see your PR 🙌. Hope you enjoyed working on it, and I hope to see more PRs from you 😁

Regarding the PR, can you please update the recipes as well? I don't see optimizer_in_bwd used in any KD config. Also, can you please add your config YAML and tune run command to the PR description? It would make the changes much easier to test and reproduce.

Ankur-singh · Feb 26 '25

Hey @Ankur-singh, I really loved working on the PR. Thanks for pushing me towards this; I want to keep contributing to the project.

Sure, sorry for missing that. I'll try to upload the configs over the weekend and update the recipes as well.

rajuptvs · Mar 08 '25

Hey @Ankur-singh, apologies for the delay. I have updated the recipes and added the config.yaml and tune run command. Please do let me know if I missed anything else.

rajuptvs · Mar 30 '25

Can you rebase on main?

joecummings · Mar 31 '25

> Can you rebase on main?

Hi Joe, apologies for the delay. I have rebased on main. Please let me know if any other changes are needed. Thanks in advance!

rajuptvs · Apr 17 '25

Codecov Report

Attention: Patch coverage is 0% with 62 lines in your changes missing coverage. Please review.

Project coverage is 64.73%. Comparing base (f3e4747) to head (3220b41). Report is 5 commits behind head on main.

Files with missing lines | Patch % | Lines
recipes/knowledge_distillation_distributed.py | 0.00% | 62 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2395      +/-   ##
==========================================
- Coverage   65.78%   64.73%   -1.06%     
==========================================
  Files         396      395       -1     
  Lines       23764    23560     -204     
==========================================
- Hits        15634    15251     -383     
- Misses       8130     8309     +179     

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.


codecov-commenter · Apr 22 '25

This looks like it's ready to land once we get the tests passing. Can you please merge main, resolve any conflicts, and verify all tests pass?

pbontrager · May 05 '25