[wip] context parallelism
Initial implementation of context parallelism in torchtune.
Initial test
```bash
tune run --nproc_per_node 8 full_finetune_distributed --config llama3/8B_full \
  context_parallel_dim=4 \
  metric_logger=torchtune.training.metric_logging.WandBLogger \
  metric_logger.project=context-parallel \
  metric_logger.name=llama3-8b-cp4-dp2
```
Also confirmed that we can run a 1M sequence length on a single node (will paste results here shortly).
Still to test
Should test (a) equivalent loss curves and (b) the expected memory improvements on a long-context dataset for each of the following:
- [ ] CP only (given above)
- [ ] CP + TP
- Currently blocked until #2667 lands
- [ ] CP + DP shard
- [ ] CP + DP replicate
- [ ] Composability with flex attention
- Need to look at https://github.com/pytorch/pytorch/pull/151497 and https://github.com/pytorch/torchtitan/pull/1160 (thanks @XilunWu)
- [ ] Composability with activation checkpointing + offloading
- [ ] Composability with optimizer in backward
- [ ] Composability with fp8 (will this work?)
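For background on what `context_parallel_dim` does: context parallelism shards the sequence dimension across CP ranks, and PyTorch's ring-attention implementation uses a load-balanced ("zigzag") split so that causal attention costs even out across ranks. Below is a minimal, illustrative sketch of that chunk assignment; the helper name and signature are assumptions for this example, not torchtune's actual implementation.

```python
def cp_chunk_indices(seq_len: int, cp_degree: int, rank: int) -> list[range]:
    """Load-balanced sequence sharding for context parallelism.

    The sequence is split into 2 * cp_degree chunks; rank i holds chunks
    i and (2 * cp_degree - 1 - i). Under causal attention, early chunks
    are cheap and late chunks are expensive, so pairing them balances
    work across ranks.
    """
    assert seq_len % (2 * cp_degree) == 0, "seq_len must divide evenly"
    chunk = seq_len // (2 * cp_degree)
    first, second = rank, 2 * cp_degree - 1 - rank
    return [
        range(first * chunk, (first + 1) * chunk),
        range(second * chunk, (second + 1) * chunk),
    ]

# With seq_len=16 and cp_degree=4, each rank holds 4 tokens:
for r in range(4):
    print(r, [list(rng) for rng in cp_chunk_indices(16, 4, r)])
# rank 0 gets tokens [0, 1] and [14, 15], rank 3 gets [6, 7] and [8, 9], etc.
```

Note that each rank holds `seq_len / cp_degree` tokens total, which is where the memory savings for long contexts come from.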
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2668
- :page_facing_up: Preview Python docs built from this PR
:white_check_mark: No Failures
As of commit 93085e36f067a9d195dd56f7a10621488255fe78 with merge base 0d906758cde5a4a705f8f545009132b867f28f80:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Codecov Report
Attention: Patch coverage is 14.86486% with 63 lines in your changes missing coverage. Please review.
Project coverage is 60.02%. Comparing base (0d90675) to head (2c7ad46). Report is 2 commits behind head on main.
Additional details and impacted files
```
@@            Coverage Diff            @@
##             main    #2668       +/-   ##
==========================================
+ Coverage    7.75%   60.02%   +52.27%
==========================================
  Files         376      437       +61
  Lines       23117    26765     +3648
==========================================
+ Hits         1792    16066    +14274
+ Misses      21325    10699    -10626
```
:umbrella: View full report in Codecov by Sentry.
@ebsmothers Why are we not sharding the FSDP model on the dp_shard_cp dimension instead of dp_shard, like in torchtitan? We could fit a larger sequence length and training could become faster.
Hi @SKRohit nice catch, we should be doing this. Would you be interested in opening a PR with the fix?
@ebsmothers Sure, I am working on it and will push a PR soon.
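To illustrate the point above: in torchtitan the dp_shard and cp mesh dimensions are flattened into a single dp_shard_cp dimension, and FSDP shards parameters over that flattened group, so each GPU holds a shard that is cp times smaller. A rough sketch of the resulting group sizes (the helper name and the dimension names follow torchtitan's convention and are assumptions, not torchtune code):

```python
def fsdp_shard_group_size(
    world_size: int,
    cp: int,
    dp_replicate: int = 1,
    tp: int = 1,
    flatten_cp: bool = True,
) -> int:
    """Size of the process group FSDP shards parameters over.

    Sharding over dp_shard alone leaves each per-GPU parameter shard
    cp times larger than sharding over the flattened dp_shard_cp group.
    """
    dp_shard = world_size // (dp_replicate * cp * tp)
    return dp_shard * cp if flatten_cp else dp_shard

# 8 GPUs with cp=4: sharding over dp_shard only splits params 2 ways,
# while dp_shard_cp splits them 8 ways -> smaller per-GPU shards.
print(fsdp_shard_group_size(8, cp=4, flatten_cp=False))  # 2
print(fsdp_shard_group_size(8, cp=4, flatten_cp=True))   # 8
```

The smaller parameter shards free up memory that can instead go to activations, which is why the fix should allow a larger sequence length.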