
[RFC] Integration with DCP - Benchmark Results

Open LucasLLC opened this issue 11 months ago • 1 comment

Context

We've pivoted this PR from an RFC on torchtune + DCP UX to a PR that showcases async saving benchmarks. If we reach consensus that we want an async checkpointing integration in torchtune, I'll pull the latest GPU recipes and implement this code there in a less hacky way. Please don't review this PR for code cleanliness; it's a bit of a dumpster fire in its current state.

Results:

full_finetune:

| world_size | main (full state dict), seconds | dcp (sharded state dict), seconds |
|-----------:|--------------------------------:|----------------------------------:|
| 2          | 44.33777084                     | 10.303921                          |
| 4          | 68.12809                        | 8.397324                           |
| 8          | 89.22356                        | 2.483598                           |

What's shown here is the blocking portion of saving -- specifically, the part that blocks the training loop. LLM training is a great case for async checkpointing since the training loop between checkpoints is fairly long. This means that by the time we reach the end of an epoch, serialization of the previous checkpoint has finished, so we never have to wait on it before staging another one.
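A minimal sketch of how this could plug into a training loop, assuming a PyTorch build that exposes `torch.distributed.checkpoint.async_save` (names like `model`, `optimizer`, `dataloader`, `num_epochs`, and `output_dir` are placeholders, not torchtune APIs):

```python
# Sketch only: the staging step is the blocking portion measured above;
# serialization to disk continues in the background after async_save returns.
import torch.distributed.checkpoint as dcp

checkpoint_future = None

for epoch in range(num_epochs):
    for batch in dataloader:
        ...  # forward / backward / optimizer step

    # If the previous async save is still serializing, wait for it before
    # staging a new checkpoint. With long epochs this wait is ~0 in practice.
    if checkpoint_future is not None:
        checkpoint_future.result()

    state_dict = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }

    checkpoint_future = dcp.async_save(
        state_dict,
        checkpoint_id=f"{output_dir}/epoch_{epoch}",
    )
```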

In the full_finetune case, we also notice that the scaling is inverted relative to main as we add ranks. Whereas the current implementation's save latency grows with the number of ranks, the sharded state dict gets smaller per rank as world_size increases, a fact that DCP takes advantage of. Meaning: "more GPUs, faster save."
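To illustrate where the per-rank size win comes from, here is a hedged sketch of the sharded vs. full state dict calls, assuming an FSDP-wrapped `model` (placeholder name); this is not torchtune's exact checkpointing code:

```python
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# Sharded: each rank keeps only its own shards, so the per-rank payload
# shrinks as world_size grows -- this is what DCP serializes above.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    sharded_sd = model.state_dict()

# Full (what main does today): parameters are all-gathered and the complete
# state dict is materialized before torch.save, so latency grows with scale.
full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
    full_sd = model.state_dict()
```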

lora_finetune:

| world_size | main (full state dict), seconds | dcp (sharded state dict), seconds |
|-----------:|--------------------------------:|----------------------------------:|
| 2          | 0.736959                        | 1.0394                             |
| 4          | 3.119773                        | 3.502441                           |
| 8          | 0.3324818                       | 1.03378                            |

In the LoRA case, we'll notice performance is actually at parity or worse for sharded state dict + DCP. There are a couple of reasons for this:

  1. Most of the time spent saving the LoRA checkpoints is actually spent in FSDP's state_dict calls. Corresponding to the table above, on the main branch with 2 ranks the full state dict call takes ~0.624 seconds, and the sharded state dict can take slightly longer to materialize.

  2. The checkpoint files are really small, roughly 16 MB - 48 MB. At that size the async overhead is almost equivalent to the serialization time of just using torch.save (at least that's what I'm seeing in this case); see the measurement sketch after this list.
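For reference, a hedged sketch of how the blocking portion could be measured on the LoRA recipe (not necessarily the exact harness used for the numbers above; `state_dict` and `ckpt_dir` are placeholders):

```python
import time

import torch.distributed as dist
import torch.distributed.checkpoint as dcp

dist.barrier()  # line ranks up so we time the save itself, not stragglers
t0 = time.perf_counter()
# Only the staging step blocks the training loop; serialization continues
# in the background after async_save returns.
future = dcp.async_save(state_dict, checkpoint_id=ckpt_dir)
blocking_s = time.perf_counter() - t0
print(f"rank {dist.get_rank()}: blocked for {blocking_s:.3f}s")
```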

Proposal:

For intermediate checkpoints in multi-GPU recipes, we should implement DCP.async_save as the checkpointing solution.

Some decision points (comments are welcome):

  1. We can go about producing an un-sharded torch.save file in two ways. We can either: a) fall back to the slower torch.save + full state dict implementation on the final save of the last epoch, or b) always save in the DCP format and ask users to convert it themselves using torch.distributed.checkpoint.format_utils (see the sketch after this list).

  2. Using the DCP checkpointer for lora_finetune would keep things consistent, but could incur some (pretty minimal) performance hits. Any thoughts?
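For option (b) in decision point 1, here is a sketch of the user-side conversion, assuming a PyTorch build that ships `torch.distributed.checkpoint.format_utils.dcp_to_torch_save` (paths are placeholders):

```python
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Collapse a DCP checkpoint directory into a single, un-sharded torch.save file.
dcp_to_torch_save("output/epoch_2", "output/epoch_2/model.pt")
```

If we go this route, we'd probably want to surface this one-liner in the recipe docs rather than asking users to dig it out of the DCP documentation.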

Changelog

  • ... WIP

Test plan

  • .... WIP

LucasLLC avatar Mar 04 '24 16:03 LucasLLC
