
[wip] context parallelism

Open ebsmothers opened this issue 11 months ago • 1 comments

Initial implementation of context parallelism in torchtune.

Initial test

tune run --nproc_per_node 8 full_finetune_distributed --config llama3/8B_full \
  context_parallel_dim=4 metric_logger=torchtune.training.metric_logging.WandBLogger \
  metric_logger.project=context-parallel metric_logger.name=llama3-8b-cp4-dp2
[Screenshot: 2025-05-02 at 4:08:51 PM]

Also confirmed that we can run a 1M sequence length on a single node (will paste results here shortly).
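To make the parallelism layout concrete: the command above launches 8 ranks arranged as a 2 × 4 (dp × cp) grid, and each CP rank holds only `seq_len / cp` tokens of every sample. A minimal pure-Python sketch of the sequence split (hypothetical helper names; the real sharding lives inside torchtune/PyTorch, and production implementations typically use a load-balanced split rather than contiguous chunks):

```python
# Sketch of how context parallelism splits the sequence dimension.
# `cp_shard` is a hypothetical standalone helper, not torchtune's API.

def cp_shard(tokens, cp_degree, cp_rank):
    """Return the contiguous slice of `tokens` owned by `cp_rank`.

    Real implementations often use a load-balanced (e.g. zig-zag)
    split for causal attention; a contiguous split keeps the idea clear.
    """
    assert len(tokens) % cp_degree == 0, "seq len must divide cp_degree"
    chunk = len(tokens) // cp_degree
    return tokens[cp_rank * chunk:(cp_rank + 1) * chunk]

# 8 ranks = dp 2 x cp 4, matching `--nproc_per_node 8 context_parallel_dim=4`
cp_degree = 4
seq = list(range(16))  # stand-in for a 16-token sequence

shards = [cp_shard(seq, cp_degree, r) for r in range(cp_degree)]
print(shards[0])  # first CP rank holds tokens [0, 1, 2, 3]
```

Each rank's activations then scale with `seq_len / cp`, which is what makes the 1M-token run above fit on a single node.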

Still to test

Should test (a) equivalent loss curves and (b) the requisite memory improvements on a long-context dataset for each of the following:

  • [ ] CP only (given above)
  • [ ] CP + TP
    • Currently blocked until #2667 lands
  • [ ] CP + DP shard
  • [ ] CP + DP replicate
  • [ ] Composability with flex attention
    • Need to look at https://github.com/pytorch/pytorch/pull/151497 and https://github.com/pytorch/torchtitan/pull/1160 (thanks @XilunWu)
  • [ ] Composability with activation checkpointing + offloading
  • [ ] Composability with optimizer in backward
  • [ ] Composability with fp8 (will this work?)
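A hedged sketch of what checking item (a), "equivalent loss curves," could look like in practice: compare per-step losses from a CP run against a baseline run within a relative tolerance. The tolerance and the sample numbers below are illustrative assumptions, not values from this PR:

```python
# Illustrative check for "equivalent loss curves": every step's loss
# from the candidate run must lie within rel_tol of the baseline.
# rel_tol=0.02 is an assumed threshold, not a number from this PR.

def curves_equivalent(baseline, candidate, rel_tol=0.02):
    """True if candidate tracks baseline within rel_tol at every step."""
    if len(baseline) != len(candidate):
        return False
    return all(
        abs(b - c) <= rel_tol * abs(b)
        for b, c in zip(baseline, candidate)
    )

baseline_loss = [2.31, 1.98, 1.75, 1.62]   # e.g. DP-only run (made-up values)
cp_loss       = [2.30, 1.99, 1.74, 1.63]   # e.g. CP=4 run (made-up values)

print(curves_equivalent(baseline_loss, cp_loss))  # True
```

In practice the curves would come from the WandB runs configured above rather than hard-coded lists.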

ebsmothers avatar May 02 '25 23:05 ebsmothers

:link: Helpful Links

:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2668

Note: Links to docs will display an error until the docs builds have been completed.

:white_check_mark: No Failures

As of commit 93085e36f067a9d195dd56f7a10621488255fe78 with merge base 0d906758cde5a4a705f8f545009132b867f28f80 (image): :green_heart: Looks good so far! There are no failures yet. :green_heart:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot[bot] avatar May 02 '25 23:05 pytorch-bot[bot]

Codecov Report

Attention: Patch coverage is 14.86486% with 63 lines in your changes missing coverage. Please review.

Project coverage is 60.02%. Comparing base (0d90675) to head (2c7ad46). Report is 2 commits behind head on main.

Files with missing lines Patch % Lines
torchtune/training/_distributed.py 20.00% 40 Missing :warning:
recipes/full_finetune_distributed.py 0.00% 13 Missing :warning:
recipes/qat_distributed.py 0.00% 10 Missing :warning:
Additional details and impacted files
@@            Coverage Diff             @@
##            main    #2668       +/-   ##
==========================================
+ Coverage   7.75%   60.02%   +52.27%     
==========================================
  Files        376      437       +61     
  Lines      23117    26765     +3648     
==========================================
+ Hits        1792    16066    +14274     
+ Misses     21325    10699    -10626     

:umbrella: View full report in Codecov by Sentry.
:loudspeaker: Have feedback on the report? Share it here.


codecov-commenter avatar May 28 '25 23:05 codecov-commenter

@ebsmothers Why are we not sharding the FSDP model on the dp_shard_cp mesh dimension instead of dp_shard, as torchtitan does? We could fit a larger sequence length and training could become faster.

SKRohit avatar Jul 10 '25 12:07 SKRohit

Hi @SKRohit nice catch, we should be doing this. Would you be interested in opening a PR with the fix?
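For context on the fix being discussed: the torchtitan-style layout shards FSDP over the flattened dp_shard × cp sub-mesh, so every rank holds 1/(dp_shard·cp) of the parameters instead of 1/dp_shard. A pure-Python sketch of the resulting rank groups (the mesh shape dp_shard=2, cp=4 is assumed for illustration; real code would build a DeviceMesh via torch.distributed and flatten the two dims):

```python
# Rank-group arithmetic for sharding FSDP over dp_shard vs. dp_shard_cp.
# The (dp_shard=2, cp=4) shape below is an assumption for illustration.

def fsdp_groups(dp_shard, cp, flatten):
    """Return the groups of ranks that together hold one full copy of
    the sharded parameters.

    Ranks are laid out row-major on a (dp_shard, cp) grid. Without
    flattening, each FSDP group is a column (fixed cp index) of size
    dp_shard, so each rank holds 1/dp_shard of the model. With the
    dp_shard_cp flattening, all dp_shard * cp ranks form one group,
    so each rank holds only 1/(dp_shard * cp) of the model.
    """
    grid = [[d * cp + c for c in range(cp)] for d in range(dp_shard)]
    if flatten:
        return [[r for row in grid for r in row]]
    return [[grid[d][c] for d in range(dp_shard)] for c in range(cp)]

print(fsdp_groups(2, 4, flatten=False))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(fsdp_groups(2, 4, flatten=True))   # [[0, 1, 2, 3, 4, 5, 6, 7]]
```

The larger group in the flattened case is what frees per-rank parameter memory and allows the longer sequence lengths mentioned above.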

ebsmothers avatar Jul 10 '25 13:07 ebsmothers

@ebsmothers Sure, I am working on it. Will push a PR soon.

SKRohit avatar Jul 10 '25 14:07 SKRohit