[RFC] Integration with DCP - Benchmark Results
Context
We've pivoted this PR from an RFC on torchtune + DCP UX to a PR that showcases async saving benchmarks. If we reach consensus that we want an async checkpointing integration in torchtune, I'll pull the latest GPU recipes and implement this code there in a less hacky way. Please don't review this PR for code cleanliness; it's a bit of a dumpster fire in its current state.
Results:
full_finetune:
world_size | main, full state dict (s) | dcp, sharded state dict (s)
---|---|---
2 | 44.33777084 | 10.303921
4 | 68.12809 | 8.397324
8 | 89.22356 | 2.483598
What's shown here is the blocking portion of saving, i.e. the time that actually stalls the training loop. LLM training is a great case for async checkpointing since the training loops are fairly long: by the time we reach the end of an epoch, serialization of the previous checkpoint has already finished, so we never have to wait on it before staging the next one.
In the `full_finetune` case, we also see an inverse relation to main with respect to the number of ranks. The current implementation's latency grows with the number of ranks, whereas each rank's sharded state dict shrinks as world_size increases, and DCP takes advantage of that. In short: more GPUs, faster save.
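For concreteness, here's a minimal sketch of how the async save could sit in a training loop. It assumes `torch.distributed.checkpoint.async_save` (available in recent PyTorch releases); the surrounding loop, the `model`/`optimizer`/`dataloader`/`output_dir` names, and the checkpoint path are illustrative placeholders, not torchtune's actual recipe code.

```python
import torch.distributed.checkpoint as dcp

checkpoint_future = None

for epoch in range(num_epochs):
    for batch in dataloader:
        ...  # forward / backward / optimizer step

    # Only block if the previous checkpoint is still serializing; with long
    # epochs this wait is effectively zero by the time we get here.
    if checkpoint_future is not None:
        checkpoint_future.result()

    # Building the (sharded) state dict from the FSDP-wrapped model is elided.
    # async_save stages the tensors (a short blocking copy) and then writes to
    # storage in the background while the next epoch trains.
    state_dict = {"model": model.state_dict(), "optim": optimizer.state_dict()}
    checkpoint_future = dcp.async_save(
        state_dict, checkpoint_id=f"{output_dir}/epoch_{epoch}"
    )
```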
lora_finetune:
world_size | main, full state dict (s) | dcp, sharded state dict (s)
---|---|---
2 | 0.736959 | 1.0394
4 | 3.119773 | 3.502441
8 | 0.3324818 | 1.03378
In the LoRA case, we'll notice that performance is actually at parity or slightly worse for sharded state dict + DCP. There are a couple of reasons for this:

- Most of the time spent saving the LoRA checkpoints is actually spent in FSDP's `state_dict` calls. Corresponding to the table above, on the main branch with 2 ranks the full state dict takes ~0.624 seconds, and the sharded state dict can take slightly longer to materialize.
- The file sizes are really small, roughly 16 MB - 48 MB. At that size, the async overhead is almost equivalent to the serialization time of just using `torch.save` (at least, that's what I'm seeing in this case).
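To make the first point concrete, here's a rough sketch of pulling a full vs. sharded state dict out of an FSDP-wrapped LoRA model; it is not the benchmark code from this PR, and `fsdp_model` plus the timing/printing are placeholders. The `FSDP.state_dict_type` context manager governs which flavor `model.state_dict()` returns, and that call is where most of the save time goes for these small checkpoints.

```python
import time

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType


def timed_state_dict(model, sd_type):
    """Return model.state_dict() under the given FSDP state dict type, timing it."""
    with FSDP.state_dict_type(model, sd_type):
        start = time.perf_counter()
        sd = model.state_dict()
        elapsed = time.perf_counter() - start
    if dist.get_rank() == 0:
        print(f"{sd_type.name}: {elapsed:.3f}s")
    return sd


# The full state dict is what main's torch.save path consumes; the sharded
# state dict is what DCP consumes. For small LoRA checkpoints both spend
# comparable time inside this call, which is why DCP shows no win here.
full_sd = timed_state_dict(fsdp_model, StateDictType.FULL_STATE_DICT)
sharded_sd = timed_state_dict(fsdp_model, StateDictType.SHARDED_STATE_DICT)
```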
Proposal:
For intermediate checkpoints in multi-GPU recipes, we should adopt DCP's `async_save` as the checkpointing solution.
Some decision points (comments are welcome):

- We can go about converting to an un-sharded `torch.save` file in two ways. We can either: (a) go back to the slower `torch.save` + full state dict implementation on the last save call of the last epoch, or (b) always save in the DCP format and ask users to convert it themselves using `torch.distributed.checkpoint.format_utils` (see the sketch below).
- Using the DCP checkpointer for `lora_finetune` would keep things consistent, but could incur some (pretty minimal) performance hits. Any thoughts?
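For option (b), the conversion users would run is essentially a one-liner. A sketch with placeholder paths, assuming the `dcp_to_torch_save` utility shipped in recent PyTorch releases:

```python
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Convert a DCP checkpoint directory into a single un-sharded torch.save file.
# Paths are placeholders for wherever the recipe wrote its final checkpoint.
dcp_to_torch_save("output/checkpoints/epoch_2", "output/checkpoints/epoch_2.pt")
```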
Changelog
- ... WIP
Test plan
- .... WIP