Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer

Open JasonZhu1313 opened this issue 1 year ago • 6 comments

What does this PR do?

Integrate the Liger (LinkedIn GPU Efficient Runtime) Kernel into the HF Trainer behind an optional flag.

Fixes # (issue)

Before submitting

  • [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [x] Did you read the contributor guideline, Pull Request section?
  • [x] Was this discussed/approved via a Github issue or the forum? Please add a link to it if that's the case.
  • [x] Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • [x] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

Tests:

  • [x] pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_apply_liger_kernel
pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_use_liger_kernel_patching tests/trainer/test_trainer.py::TrainerIntegrationTest::test_use_liger_kernel_trainer

======================================= test session starts ========================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /content/transformers-jaszhu
configfile: pyproject.toml
plugins: rich-0.1.1, timeout-2.3.1, xdist-3.6.1
collected 2 items

tests/trainer/test_trainer.py ..                                                             [100%]

======================================== 2 passed in 9.47s =========================================


  • [x] E2E test
{'loss': 1.6157, 'grad_norm': 32.0, 'learning_rate': 2.4324324324324326e-07, 'epoch': 0.0, 'num_input_tokens_seen': 60416, 'step': 3, 'step_time_sec': 4.87, 'avg_step_time_sec': 6.82, 'time_to_completion_sec': 4970.4, 'estimated_total_time_sec': 4990.85, 'step_peak_memory_allocated_MB': 76728.45, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 79692.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3138.55, 'avg_tokens_per_second': 3158.65}
{'loss': 1.5678, 'grad_norm': 26.875, 'learning_rate': 3.2432432432432436e-07, 'epoch': 0.01, 'num_input_tokens_seen': 84992, 'step': 4, 'step_time_sec': 7.82, 'avg_step_time_sec': 7.15, 'time_to_completion_sec': 5206.53, 'estimated_total_time_sec': 5235.14, 'step_peak_memory_allocated_MB': 76728.67, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 80194.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3142.99, 'avg_tokens_per_second': 3152.94}
{'loss': 1.74, 'grad_norm': 28.875, 'learning_rate': 4.0540540540540546e-07, 'epoch': 0.01, 'num_input_tokens_seen': 103936, 'step': 5, 'step_time_sec': 5.75, 'avg_step_time_sec': 6.8, 'time_to_completion_sec': 4945.07, 'estimated_total_time_sec': 4979.08, 'step_peak_memory_allocated_MB': 76728.54, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 80324.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3293.14, 'avg_tokens_per_second': 3182.59}
{'loss': 1.7297, 'grad_norm': 29.25, 'learning_rate': 4.864864864864865e-07, 'epoch': 0.01, 'num_input_tokens_seen': 124416, 'step': 6, 'step_time_sec': 6.23, 'avg_step_time_sec': 6.69, 'time_to_completion_sec': 4855.78, 'estimated_total_time_sec': 4895.91, 'step_peak_memory_allocated_MB': 76728.57, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 80288.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3285.22, 'avg_tokens_per_second': 3201.72}
{'loss': 1.6393, 'grad_norm': 27.75, 'learning_rate': 5.675675675675676e-07, 'epoch': 0.01, 'num_input_tokens_seen': 153920, 'step': 7, 'step_time_sec': 9.22, 'avg_step_time_sec': 7.11, 'time_to_completion_sec': 5154.73, 'estimated_total_time_sec': 5204.5, 'step_peak_memory_allocated_MB': 76728.78, 'total_peak_memory_allocated_MB': 76728.78, 'step_peak_memory_reserved_MB': 79652.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3200.77, 'avg_tokens_per_second': 3201.51}
{'loss': 1.5642, 'grad_norm': 27.25, 'learning_rate': 6.486486486486487e-07, 'epoch': 0.01, 'num_input_tokens_seen': 170752, 'step': 8, 'step_time_sec': 5.49, 'avg_step_time_sec': 6.88, 'time_to_completion_sec': 4980.15, 'estimated_total_time_sec': 5035.18, 'step_peak_memory_allocated_MB': 76728.49, 'total_peak_memory_allocated_MB': 76728.78, 'step_peak_memory_reserved_MB': 79988.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3065.48, 'avg_tokens_per_second': 3186.0}

  • [x] When the installed liger version is too low, this error is thrown: ImportError: You have set `use_liger` to `True` but liger-kernel >= 0.1.0 is not available. Please install it with `pip install liger-kernel`
  • [x] Model type is correctly extracted as "llama"
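The two checks above (a version guard on the optional dependency, then a lookup by model type) can be sketched with the standard library alone. This is an illustrative sketch, not the code from this PR: `version_tuple`, `is_installed`, `PATCHERS`, and `maybe_patch` are hypothetical names, and the registry entry is a stand-in for a real patch function.

```python
from importlib import metadata


def version_tuple(text, width=3):
    """Parse '0.1' -> (0, 1, 0). Stops at the first non-numeric piece,
    so pre-release suffixes such as 'dev1' are ignored."""
    numbers = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        numbers.append(int(piece))
    numbers = numbers[:width]
    return tuple(numbers + [0] * (width - len(numbers)))


def is_installed(package, min_version):
    """True if `package` is installed at `min_version` or newer."""
    try:
        installed = metadata.version(package)
    except metadata.PackageNotFoundError:
        return False
    return version_tuple(installed) >= version_tuple(min_version)


# Hypothetical registry: model type -> callable that applies the patches.
PATCHERS = {"llama": lambda: "llama kernels patched"}


def maybe_patch(use_liger, model_type, package="liger-kernel", min_version="0.1.0"):
    """Apply kernel patches only when the flag is set and the package is usable."""
    if not use_liger:
        return None  # default path: nothing is patched
    if not is_installed(package, min_version):
        raise ImportError(
            f"You have set `use_liger` to `True` but {package} >= {min_version} "
            f"is not available. Please install it with `pip install {package}`"
        )
    patcher = PATCHERS.get(model_type)
    # Unknown model types fall through without patching (an assumption of this sketch).
    return patcher() if patcher is not None else None
```

Keeping the flag off by default means users without the optional package never reach the import check.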

Screenshot 2024-08-20 at 3 13 45 PM

Test conditions: LLaMA 3-8B, Batch Size = 64, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 4 A100s.

With use_liger=True, memory usage and throughput both improve compared to the default, use_liger=False.
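As a sanity check on the throughput numbers in the E2E log entries above, the logged `step_tokens_per_second` values look consistent with "tokens added in this step divided by `step_time_sec`". That relationship is inferred from the numbers, not stated in the PR. A minimal sketch, with the four fields copied from the log for steps 3 to 8 and a hypothetical helper name:

```python
# (step, num_input_tokens_seen, step_time_sec, step_tokens_per_second),
# copied from the E2E log entries for steps 3 to 8.
LOG = [
    (3, 60416, 4.87, 3138.55),
    (4, 84992, 7.82, 3142.99),
    (5, 103936, 5.75, 3293.14),
    (6, 124416, 6.23, 3285.22),
    (7, 153920, 9.22, 3200.77),
    (8, 170752, 5.49, 3065.48),
]


def derived_throughput(log):
    """For each step after the first, return
    (step, tokens_in_step, derived_tokens_per_second, logged_tokens_per_second)."""
    rows = []
    for prev, cur in zip(log, log[1:]):
        tokens_in_step = cur[1] - prev[1]  # growth of num_input_tokens_seen
        rows.append((cur[0], tokens_in_step, tokens_in_step / cur[2], cur[3]))
    return rows


for step, tokens, derived, logged in derived_throughput(LOG):
    print(f"step {step}: {tokens} tokens, derived {derived:.1f} tok/s, logged {logged:.1f} tok/s")
```

The derived and logged values differ slightly, presumably because the logged step time is rounded to two decimals.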

[Figures: memory usage and throughput, use_liger=True vs. use_liger=False]

Note: for a more detailed benchmark setup and further efficiency gains for multi-head training (Medusa), please refer to the original repo: https://github.com/linkedin/Liger-Kernel (the repo will be public soon!)

JasonZhu1313 avatar Aug 17 '24 00:08 JasonZhu1313

cc @ArthurZucker @muellerzr

amyeroberts avatar Aug 19 '24 09:08 amyeroberts

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

lgtm!

ByronHsu avatar Aug 21 '24 04:08 ByronHsu

@JasonZhu1313 if you run make fixup it should fix the quality tests :) Otherwise as Marc said, let us know when we're okay to land this and we'll merge it immediately 🚀

muellerzr avatar Aug 22 '24 12:08 muellerzr

Thanks, the repo will be open-sourced on Friday.

JasonZhu1313 avatar Aug 22 '24 15:08 JasonZhu1313

The code is now public, so we are ready to merge the PR!

JasonZhu1313 avatar Aug 22 '24 20:08 JasonZhu1313

Nice ! Merging !

SunMarc avatar Aug 23 '24 11:08 SunMarc