torchtune
KD recipe update
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [ ] update tests and/or documentation
- [ ] other (please add here)
Please link to any issues this PR addresses. #1959
Changelog
What are the changes made in this PR? This PR enables activation offloading and optimizer-in-backward in the knowledge distillation recipes, similar to #1847 and #1833, allowing a reduction in memory usage.
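The claimed memory reduction can also be sanity-checked outside of torchtune. Below is a minimal sketch in plain PyTorch (assuming a CUDA device; the training steps themselves are elided) for comparing peak allocated memory with a feature enabled vs. disabled; the wandb runs below should show the same comparison via the recipes' logged memory stats.

```python
import torch

# Reset the CUDA peak-memory counter, run a few training steps, then read the
# high-water mark. Comparing this value with activation offloading or
# optimizer-in-backward toggled on/off shows the memory saving.
torch.cuda.reset_peak_memory_stats()

# ... run a few training steps of the recipe here ...

peak_gib = torch.cuda.max_memory_allocated() / 1024**3
print(f"Peak allocated CUDA memory: {peak_gib:.2f} GiB")
```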
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
- [x] run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
- [ ] add unit tests for any new functionality
- [x] update docstrings for any new or updated methods or classes
- [ ] run unit tests via pytest tests
- [ ] run recipe tests via pytest tests -m integration_test
- [x] manually run any new or modified recipes with sufficient proof of correctness
- [ ] include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)
Below are wandb runs for each:
Activation Offloading (Single):
Activation Offloading (Distributed):
Optimizer in Backward (Single):
Optimizer in Backward (Distributed):
Config.yaml for activation_offloading
output_dir: /tmp/torchtune/qwen2_1_5_to_0_5B/KD_lora_distributed

model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  apply_lora_to_mlp: true
  lora_rank: 32
  lora_alpha: 64

teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files:
  - model.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct
  checkpoint_files:
  - model.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2

resume_from_checkpoint: false

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: false
seed: null
shuffle: true
batch_size: 8

optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 0.0003

lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5

epochs: 1
max_steps_per_epoch: null
compile: false
gradient_accumulation_steps: 8

device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: true

profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: false
  output_dir: ${output_dir}/profiling_outputs
  cpu: true
  cuda: true
  profile_memory: false
  with_stack: false
  record_shapes: true
  with_flops: false
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1

metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
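For context on the enable_activation_offloading flag set above: activations saved for the backward pass are moved to host memory during the forward pass and copied back to the GPU when backward needs them, trading transfer time for lower peak GPU memory. The recipe wires this up through torchtune's own activation-offloading utilities, but the idea can be illustrated with plain PyTorch's torch.autograd.graph.save_on_cpu (toy model below, sizes arbitrary):

```python
import torch
import torch.nn as nn

# Toy model standing in for the student; not the recipe's actual model.
model = nn.Sequential(nn.Linear(1024, 4096), nn.ReLU(), nn.Linear(4096, 1024)).cuda()
x = torch.randn(8, 1024, device="cuda")

# Inside this context, tensors saved for backward are offloaded to pinned CPU
# memory and fetched back on demand during the backward pass.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    loss = model(x).pow(2).mean()

loss.backward()  # activations stream back from CPU as gradients are computed
```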
Config.yaml for opt_in_bwd
output_dir: /tmp/torchtune/qwen2_1_5_to_0_5B/KD_lora_distributed

model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  apply_lora_to_mlp: true
  lora_rank: 32
  lora_alpha: 64

teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files:
  - model.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2

teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-1.5B-Instruct
  checkpoint_files:
  - model.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2

resume_from_checkpoint: false

dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: false
seed: null
shuffle: true
batch_size: 8

optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 0.0003

lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss
kd_ratio: 0.5

epochs: 1
max_steps_per_epoch: null
compile: false
gradient_accumulation_steps: 1

device: cuda
dtype: bf16
enable_activation_checkpointing: false
enable_activation_offloading: false
optimizer_in_bwd: true

profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: false
  output_dir: ${output_dir}/profiling_outputs
  cpu: true
  cuda: true
  profile_memory: false
  with_stack: false
  record_shapes: true
  with_flops: false
  wait_steps: 5
  warmup_steps: 5
  active_steps: 2
  num_cycles: 1

metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
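For context on the optimizer_in_bwd flag set above: each parameter's optimizer step runs as soon as that parameter's gradient is ready during backward, so a full set of gradients is never resident at once; gradients are consumed immediately instead of being accumulated, which is why this config uses gradient_accumulation_steps: 1. The recipe itself relies on torchtune's optimizer-in-backward utilities; the sketch below is a minimal plain-PyTorch illustration of the mechanism, with per-parameter AdamW instances mirroring the optimizer settings above.

```python
import torch
import torch.nn as nn

# Toy model standing in for the student; not the recipe's actual model.
model = nn.Linear(1024, 1024).cuda()

# One optimizer per parameter, mirroring lr/weight_decay from the config above.
optim_per_param = {
    p: torch.optim.AdamW([p], lr=3e-4, weight_decay=0.01)
    for p in model.parameters()
}

def step_and_free(param: torch.Tensor) -> None:
    # Fires once this parameter's gradient is fully accumulated: apply the
    # update immediately, then free the gradient.
    optim_per_param[param].step()
    optim_per_param[param].zero_grad(set_to_none=True)

# register_post_accumulate_grad_hook (PyTorch >= 2.1) runs per parameter
# during backward, so parameter updates happen inside loss.backward().
for p in model.parameters():
    p.register_post_accumulate_grad_hook(step_and_free)

loss = model(torch.randn(8, 1024, device="cuda")).pow(2).mean()
loss.backward()  # updates applied here; no separate optimizer.step() call
```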
Command used to run the distributed KD recipe:
tune run --nnodes 1 --nproc_per_node 2 knowledge_distillation_distributed --config custom.yaml
Command used to run the single-device KD recipe:
tune run knowledge_distillation_single_device --config custom.yaml
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example
- [ ] I did not change any public API
- [ ] I have added an example to docs or docstrings
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2395
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:x: 2 New Failures, 1 Cancelled Job, 1 Unrelated Failure
As of commit 3220b41072776825c5dbf3837a14432ebe7a00e7 with merge base 8fd697188f25832343cc013b89b354f0f8368b78:
NEW FAILURES - The following jobs have failed:
- GPU tests / gpu_test (3.10, stable) (gh)
  tests/recipes/test_knowledge_distillation_single_device.py::TestKDSingleDeviceRecipe::test_save_and_load_merged_weights
- Lint / lint (3.10) (gh)
  Process completed with exit code 1.
CANCELLED JOB - The following job was cancelled. Please retry:
- GPU tests / gpu_test (3.11, stable) (gh)
  tests/recipes/test_knowledge_distillation_single_device.py::TestKDSingleDeviceRecipe::test_save_and_load_merged_weights
BROKEN TRUNK - The following job failed but was present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
- GPU tests / gpu_test (3.9, stable) (gh) (trunk failure)
##[error]The operation was canceled.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @rajuptvs, very happy to see your PR 🙌. Hope you enjoyed working on it, and I hope to see more PRs from you 😁
Regarding the PR, can you please update the recipes as well? I don't see optimizer_in_bwd used in any KD config. Also, can you please add your config YAML and tune run command to the PR description? That would make it much easier to test and reproduce.
Hey @Ankur-singh, I really loved working on the PR. Thanks for pushing me towards this; I want to keep contributing to the project.
Sure, sorry for the miss. I'll try to upload the configs over the weekend and update the recipes as well.
Hey @Ankur-singh, apologies for the delay. I have updated the recipes and added the config.yaml and tune run command. Please let me know if I missed anything else.
Can you rebase on main?
Hi Joe, apologies for the delay. I have rebased on main. Please let me know if any other changes are needed. Thanks in advance!
Codecov Report
Attention: Patch coverage is 0% with 62 lines in your changes missing coverage. Please review.
Project coverage is 64.73%. Comparing base (f3e4747) to head (3220b41). Report is 5 commits behind head on main.
| Files with missing lines | Patch % | Lines |
|---|---|---|
| recipes/knowledge_distillation_distributed.py | 0.00% | 62 Missing :warning: |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2395      +/-   ##
==========================================
- Coverage   65.78%   64.73%    -1.06%
==========================================
  Files         396      395        -1
  Lines       23764    23560      -204
==========================================
- Hits        15634    15251      -383
- Misses       8130     8309      +179
:umbrella: View full report in Codecov by Sentry.
This looks like it's ready to land once we get the tests passing. Can you please merge main, resolve any conflicts, and verify that all tests pass?