Integrate Liger (LinkedIn GPU Efficient Runtime) Kernel into Trainer
What does this PR do?
Integrates the Liger (LinkedIn GPU Efficient Runtime) Kernel into the HF Trainer behind an optional flag (a minimal usage sketch is included below).
Fixes # (issue)
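For reviewers, a minimal usage sketch. The flag name `use_liger` is taken from the error message quoted further down; the checkpoint name, hyperparameters, and `train_dataset` placeholder are illustrative only and not part of this PR:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

# Illustrative checkpoint; any architecture supported by Liger Kernel (e.g. LLaMA) works.
model_name = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

args = TrainingArguments(
    output_dir="./liger-out",
    per_device_train_batch_size=16,
    bf16=True,
    gradient_checkpointing=True,
    use_liger=True,  # assumed flag name from this PR; defaults to False
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,  # placeholder: your tokenized dataset
    tokenizer=tokenizer,
)
trainer.train()
```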
Before submitting
- [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [x] Did you read the contributor guideline, Pull Request section?
- [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
- [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
- [x] Did you write any new necessary tests?
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
Tests:
- [x] pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_apply_liger_kernel
pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_use_liger_kernel_patching tests/trainer/test_trainer.py::TrainerIntegrationTest::test_use_liger_kernel_trainer
======================================= test session starts ========================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /content/transformers-jaszhu
configfile: pyproject.toml
plugins: rich-0.1.1, timeout-2.3.1, xdist-3.6.1
collected 2 items
tests/trainer/test_trainer.py .. [100%]
======================================== 2 passed in 9.47s =========================================
- [x] E2E test
{'loss': 1.6157, 'grad_norm': 32.0, 'learning_rate': 2.4324324324324326e-07, 'epoch': 0.0, 'num_input_tokens_seen': 60416, 'step': 3, 'step_time_sec': 4.87, 'avg_step_time_sec': 6.82, 'time_to_completion_sec': 4970.4, 'estimated_total_time_sec': 4990.85, 'step_peak_memory_allocated_MB': 76728.45, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 79692.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3138.55, 'avg_tokens_per_second': 3158.65}
{'loss': 1.5678, 'grad_norm': 26.875, 'learning_rate': 3.2432432432432436e-07, 'epoch': 0.01, 'num_input_tokens_seen': 84992, 'step': 4, 'step_time_sec': 7.82, 'avg_step_time_sec': 7.15, 'time_to_completion_sec': 5206.53, 'estimated_total_time_sec': 5235.14, 'step_peak_memory_allocated_MB': 76728.67, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 80194.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3142.99, 'avg_tokens_per_second': 3152.94}
{'loss': 1.74, 'grad_norm': 28.875, 'learning_rate': 4.0540540540540546e-07, 'epoch': 0.01, 'num_input_tokens_seen': 103936, 'step': 5, 'step_time_sec': 5.75, 'avg_step_time_sec': 6.8, 'time_to_completion_sec': 4945.07, 'estimated_total_time_sec': 4979.08, 'step_peak_memory_allocated_MB': 76728.54, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 80324.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3293.14, 'avg_tokens_per_second': 3182.59}
{'loss': 1.7297, 'grad_norm': 29.25, 'learning_rate': 4.864864864864865e-07, 'epoch': 0.01, 'num_input_tokens_seen': 124416, 'step': 6, 'step_time_sec': 6.23, 'avg_step_time_sec': 6.69, 'time_to_completion_sec': 4855.78, 'estimated_total_time_sec': 4895.91, 'step_peak_memory_allocated_MB': 76728.57, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 80288.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3285.22, 'avg_tokens_per_second': 3201.72}
{'loss': 1.6393, 'grad_norm': 27.75, 'learning_rate': 5.675675675675676e-07, 'epoch': 0.01, 'num_input_tokens_seen': 153920, 'step': 7, 'step_time_sec': 9.22, 'avg_step_time_sec': 7.11, 'time_to_completion_sec': 5154.73, 'estimated_total_time_sec': 5204.5, 'step_peak_memory_allocated_MB': 76728.78, 'total_peak_memory_allocated_MB': 76728.78, 'step_peak_memory_reserved_MB': 79652.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3200.77, 'avg_tokens_per_second': 3201.51}
{'loss': 1.5642, 'grad_norm': 27.25, 'learning_rate': 6.486486486486487e-07, 'epoch': 0.01, 'num_input_tokens_seen': 170752, 'step': 8, 'step_time_sec': 5.49, 'avg_step_time_sec': 6.88, 'time_to_completion_sec': 4980.15, 'estimated_total_time_sec': 5035.18, 'step_peak_memory_allocated_MB': 76728.49, 'total_peak_memory_allocated_MB': 76728.78, 'step_peak_memory_reserved_MB': 79988.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3065.48, 'avg_tokens_per_second': 3186.0}
- [x] When an older version of liger-kernel is installed, an error is thrown:
  `ImportError: You have set use_liger to True but liger-kernel >= 0.1.0 is not available. Please install it with pip install liger-kernel`
- [x] The model type is correctly extracted as "llama" (a sketch of both checks follows this list)
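For illustration only, a minimal sketch of how the version guard and the model-type dispatch could look. The `apply_liger_kernel_to_llama` import is the public patching entry point of the Liger Kernel package; the helper name `_apply_liger_kernel`, the use of `importlib.metadata`/`packaging`, and the handling of unsupported model types are assumptions, not the PR's actual implementation:

```python
import importlib.metadata

from packaging import version

MIN_LIGER_VERSION = "0.1.0"  # version floor quoted in the error message above


def _apply_liger_kernel(model_type: str) -> None:
    """Sketch: patch supported architectures with Liger's fused kernels."""
    try:
        installed = importlib.metadata.version("liger-kernel")
    except importlib.metadata.PackageNotFoundError:
        installed = None

    if installed is None or version.parse(installed) < version.parse(MIN_LIGER_VERSION):
        raise ImportError(
            "You have set `use_liger` to `True` but liger-kernel >= 0.1.0 is not "
            "available. Please install it with `pip install liger-kernel`"
        )

    if model_type == "llama":
        # Monkey-patches LLaMA's RMSNorm, RoPE, SwiGLU, and cross-entropy with fused kernels.
        from liger_kernel.transformers import apply_liger_kernel_to_llama

        apply_liger_kernel_to_llama()
    # Other architectures would be dispatched here as Liger adds support for them.


# The model type itself can be read off the model config, e.g.:
# model_type = getattr(model.config, "model_type", None)  # -> "llama" for LLaMA 3-8B
```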
Test conditions: LLaMA 3-8B, Batch Size = 64, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 4 A100s.
With use_liger=True, memory usage and throughput both improve compared to the default use_liger=False (see the measurement sketch after the note below).
Note: for a more detailed benchmark setup and even larger efficiency gains for multi-head training (Medusa), please refer to the original repo: https://github.com/linkedin/Liger-Kernel (the repo will be public soon!)
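As a side note on the metrics above: per-step peak memory and tokens/second can be reproduced with standard `torch.cuda` APIs. The sketch below is illustrative, not the callback that produced these logs; `step_fn` and `num_tokens` are hypothetical placeholders:

```python
import time

import torch


def measure_step(step_fn, num_tokens: int) -> dict:
    """Time one training step and record peak CUDA memory, in the spirit of the
    step_* metrics shown in the logs above (illustrative, not the PR's code)."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()

    step_fn()  # e.g. one forward/backward/optimizer step

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return {
        "step_time_sec": round(elapsed, 2),
        "step_tokens_per_second": round(num_tokens / elapsed, 2),
        "step_peak_memory_allocated_MB": round(torch.cuda.max_memory_allocated() / 2**20, 2),
        "step_peak_memory_reserved_MB": round(torch.cuda.max_memory_reserved() / 2**20, 2),
    }
```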
cc @ArthurZucker @muellerzr
lgtm!
@JasonZhu1313 if you run `make fixup` it should fix the quality tests :) Otherwise as Marc said, let us know when we're okay to land this and we'll merge it immediately 🚀
Thanks, the repo will be open sourced on Friday.
The code is now open to the public; we are ready to merge the PR!
Nice! Merging!