torchtune
FSDP Llama3 wrapping improvements for full finetune
Context
What is the purpose of this PR? Is it to
- [x] add a new feature
- [ ] fix a bug
- [x] update tests and/or documentation
- [ ] other (please add here)
Changelog
This PR primarily seeks to improve memory efficiency for Llama3 distributed full finetuning and enable a distributed full finetune in 4x 24GB of memory. We do this with a new FSDP wrapping policy that additionally wraps the token embedding and output projection. These modules are much larger for Llama3 due to the increased vocab size, so sharding them across GPUs has a larger effect.
- Added a new API to retrieve a memory-efficient FSDP wrapping policy. Whether this policy is used is controlled by a new `memory_efficient_wrapping` flag in our configs. Currently this is only set to True for Llama3 distributed full finetuning; as follow-up work, we'll investigate this wrapping for other workloads and enable it where beneficial. A minimal sketch of such a policy follows this list.
- Added appropriate unittests.
- Integrated into `full_finetune_distributed`. Did a bit of study on potential integration into LoRA, but memory savings were less pronounced there; this needs further investigation.
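For illustration, a minimal sketch of the idea behind such a memory-efficient FSDP auto-wrap policy, assuming PyTorch's `lambda_auto_wrap_policy`; the function name and the `TransformerDecoderLayer` import path are illustrative, not the exact API added by this PR:

```python
# Sketch: in addition to each transformer layer, give the (vocab-sized) token
# embedding its own FSDP unit so its parameters are sharded across ranks.
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy

from torchtune.modules import TransformerDecoderLayer  # assumed import path


def _memory_efficient_wrap_fn(module: nn.Module) -> bool:
    # Wrap each decoder layer (the usual policy) plus the token embedding.
    # A real policy would also pick out the output projection, e.g. by module
    # name or by checking out_features against the vocab size.
    return isinstance(module, (TransformerDecoderLayer, nn.Embedding))


memory_efficient_fsdp_wrap_policy = functools.partial(
    lambda_auto_wrap_policy, lambda_fn=_memory_efficient_wrap_fn
)
# Usage: FSDP(model, auto_wrap_policy=memory_efficient_fsdp_wrap_policy, ...)
```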
Test plan
- Added unittests.
This PR only seeks to ship improvements to llama3 training.
Docs
Full finetune
Run command for full finetune: `CUDA_VISIBLE_DEVICES=0,3,6,7 tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full batch_size=1`
- With this PR: `peak_memory_active:20.830272 peak_memory_alloc:19.085376 peak_memory_reserved:23.699914752` (GB), 1.06it/s
- Without this PR: `peak_memory_active:24.170446336 peak_memory_alloc:21.988057088 peak_memory_reserved:27.908898816` (GB), 1.08it/s
- About 13% savings in allocated memory and 15% in reserved memory. This allows us to get a 4x 24GB finetune.
- NOTE: A previous version of this PR also wrapped the token embedding and output projection in their own activation checkpointing units, but this is not needed. The vocabulary size is increased, but the activations they generate are proportional to sequence length, not vocab size, so checkpointing them won't help more for Llama3 than for Llama2. A quick study of checkpointing these versus not shows roughly the same memory efficiency: with checkpointing of the token embedding and output projection we achieve `peak_memory_active:20.880037376 peak_memory_alloc:19.135141376 peak_memory_reserved:24.13821952`, while without it we achieve the numbers reported above; they are very comparable.
:link: Helpful Links
:test_tube: See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/865
- :page_facing_up: Preview Python docs built from this PR
Note: Links to docs will display an error until the docs builds have been completed.
:white_check_mark: No Failures
As of commit 6124dd5a0b93e2db59c960fc73e51ae345060e36 with merge base 7d05579fe50bc956b948b66ea9b1ecdf2fe7da8b:
:green_heart: Looks good so far! There are no failures yet. :green_heart:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
`CUDA_VISIBLE_DEVICES=0,3,6,7` 👀
Overall this makes sense to me and the memory savings for full finetune are great. My main question is around whether model type is the most natural way to expose this feature.
There's nothing about this functionality that is unique to Llama3, it's just that it proves most beneficial there. By doing things this way we are kinda making the decision for people that only Llama3 should have this feature, and supporting other models with large vocab sizes will then require updating the wrapping internals instead of just flipping a config. I know we had discussed Gemma as one specific model where this is a challenge, but I wonder if we can do an assert on the backend to raise an error if the model class is not `TransformerDecoder`, as we would expect.
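A rough sketch of the kind of backend check being suggested here; the helper name is hypothetical and `torchtune.modules.TransformerDecoder` is assumed to be the relevant class:

```python
# Hypothetical guard: fail fast if the memory-efficient policy is requested for
# a model that is not the expected TransformerDecoder.
from torchtune.modules import TransformerDecoder  # assumed import path


def _check_model_for_memory_efficient_wrapping(model) -> None:
    if not isinstance(model, TransformerDecoder):
        raise ValueError(
            "memory_efficient_wrapping is only supported for TransformerDecoder "
            f"models, but got {type(model).__name__}"
        )
```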
@ebsmothers I definitely agree here. My proposal for a way forward would be to decouple the model type from the checkpointer and offer it as a general accessor to determine which model is being trained - there's currently no easy way to go about this. And I'd like for this change to be especially focused on Llama3 (so the initial rollout of these policies will only be done for Llama3). As follow-up work we should enable it for Llama2 and investigate other models, although verifying these improvements is currently a long-running process and should ideally be done iteratively and/or by multiple folks, IMO.
@joecummings could you chime in on `ModelType` for this sort of use case, and on any alternative ways to achieve this sort of gating based on specific models here?
@rohan-varma I could totally be missing something here, but why can't we include `embedding` in the modules to wrap within the config for Llama3, rather than tie this directly to `ModelType`? That way, you can expand this to any new models that have this large embedding space, which is starting to become more popular.
cc: @ebsmothers
I haven't tested this, but will this allow 70B on 8x80GB? I was only able to full finetune 70B with CPU offloading.
> @rohan-varma I could totally be missing something here, but why can't we include `embedding` in the modules to wrap within the config for Llama3, rather than tie this directly to `ModelType`? That way, you can expand this to any new models that have this large embedding space, which is starting to become more popular. cc: @ebsmothers
@joecummings This is because `modules_to_wrap` is not configurable right now, and configuring it would be a little tricky (i.e. we'd have to parse a string like `torch.nn.Embedding` and turn it into a class).
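For context on why that parsing step would look like, a minimal sketch; the helper name is made up and error handling is omitted:

```python
# Turn a dotted string from a config (e.g. "torch.nn.Embedding") into the class
# object that FSDP's module-wrap policies expect.
import importlib


def _class_from_string(path: str) -> type:
    module_path, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)


# e.g. a config field like modules_to_wrap: ["torch.nn.Embedding"]
modules_to_wrap = {_class_from_string(name) for name in ["torch.nn.Embedding"]}
```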
> I haven't tested this, but will this allow 70B on 8x80GB? I was only able to full finetune 70B with CPU offloading.
This unfortunately won't allow 70B on 8x80GB from my experiments without CPU offloading, but I can do a bit more testing. Our current thinking is to enable full finetune for 70B models with CPU offload.
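For reference, a minimal sketch of FSDP's CPU-offload option mentioned here; how torchtune would expose it in a config is not shown, and the helper name is hypothetical:

```python
# Offload sharded parameters to CPU between uses; this trades GPU memory for
# host<->device transfer time, which is why it's the fallback for 70B full finetunes.
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP


def shard_with_cpu_offload(model: nn.Module) -> FSDP:
    # Assumes torch.distributed has already been initialized by the recipe.
    return FSDP(model, cpu_offload=CPUOffload(offload_params=True))
```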
Thanks for adding this!
I don't think I fully understand this:
> New AC wrapping policy that checkpoints the token embedding and output projections as well. Similar reason to above - they generate larger activations so it would be useful to not store those in memory.
Irrespective of the size of the vocab, the output of the embedding table would just depend on the sequence length? So why does this have anything to do with the vocab size? Or am I misunderstanding?
@kartikayk Thanks for the feedback and the review! You're totally right that this doesn't have anything to do with the vocab size, and this was an oversight on my end. I verified that if we remove the modified AC wrapping, we don't change anything about the memory improvements we're shipping here. So this PR is now limited to the FSDP wrapping changes only.
Also added a bunch more documentation to `FSDPPolicyType` to clearly explain it to the user and link back to the FSDP wrapping docs where useful. Thanks!
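For concreteness, a back-of-the-envelope comparison of the point above, assuming bf16, batch size 1, sequence length 2048, hidden dim 4096, and vocab sizes of 32,000 (Llama2) vs 128,256 (Llama3):

```python
# The embedding *weights* scale with vocab size (so FSDP sharding helps more for
# Llama3), but the embedding *output* only scales with batch/sequence/hidden size,
# so activation checkpointing it saves the same amount for Llama2 and Llama3.
BYTES = 2  # bf16
bsz, seq, hidden = 1, 2048, 4096

for name, vocab in [("llama2", 32_000), ("llama3", 128_256)]:
    weight_gb = vocab * hidden * BYTES / 1e9
    output_mb = bsz * seq * hidden * BYTES / 1e6
    print(f"{name}: embedding weight ~{weight_gb:.2f} GB, embedding output ~{output_mb:.1f} MB")
# llama2: embedding weight ~0.26 GB, embedding output ~16.8 MB
# llama3: embedding weight ~1.05 GB, embedding output ~16.8 MB
```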
Codecov Report
Attention: Patch coverage is 24.32432% with 28 lines in your changes missing coverage. Please review.
Project coverage is 26.67%. Comparing base (a978956) to head (6124dd5). Report is 29 commits behind head on main.
Files | Patch % | Lines |
---|---|---|
tests/torchtune/utils/test_distributed.py | 21.05% | 15 Missing :warning: |
torchtune/utils/_distributed.py | 27.77% | 13 Missing :warning: |
Additional details and impacted files
@@ Coverage Diff @@
## main #865 +/- ##
===========================================
- Coverage 66.39% 26.67% -39.72%
===========================================
Files 155 172 +17
Lines 6484 7182 +698
===========================================
- Hits 4305 1916 -2389
- Misses 2179 5266 +3087
:umbrella: View full report in Codecov by Sentry.