
"Universal" Checkpointing

Open jeromeku opened this issue 10 months ago • 4 comments

Is there currently an equivalent of DeepSpeed Universal Checkpointing for distributed checkpointing, DTensor, and FSDP2? That is, how can torch-native tooling be used to convert a checkpoint saved with one sharding / parallelism config to a new config, such that the sharded state dicts can be loaded directly with a new world size?

For example, train a model on 128 GPUs with FSDP (DP128) and save a checkpoint with 128 sharded state dicts. Resume training on 64 GPUs with TP2 / FSDP (DP32).

Manually, one could merge the original checkpoint's 128 shards into a single state dict, reshard it to TP2, then partition the TP shards into 32 DP partitions for a total of 64 sharded state dicts, and finally have each rank load its state dict directly (without first materializing the full state dict on any rank).

@awgu

jeromeku avatar Feb 17 '25 12:02 jeromeku

DCP today should already support "resharding" for most of the components in torchtitan, except data loading (which you can bypass via checkpoint.exclude_from_loading). Have you tried it and encountered problems? Can you share the errors?
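For reference, in the job TOML config that would look roughly like this (a sketch; the exact key name and value format for the data loader state may differ across torchtitan versions):

```toml
[checkpoint]
enable_checkpoint = true
folder = "checkpoint"
# skip restoring the data loader state so resharding to a new world size works
exclude_from_loading = "dataloader"
```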

tianyu-l avatar Feb 17 '25 20:02 tianyu-l

PTD DCP is designed to do online resharding of model and optimizer states. More specifically, if all the model parallelisms are PTD-native (fully_shard, TP, PP), then the saved checkpoint can be loaded even if the world size or the parallelism scheme changes. However, all ranks must be able to access all the checkpoint files, since DCP cannot predict which files will be required during resharding.
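As a rough sketch (the checkpoint path is a placeholder, and `model` is assumed to already be wrapped with each run's parallelisms before saving/loading):

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

CKPT_DIR = "outputs/checkpoint/step-1000"  # placeholder path

# Run A: e.g. 128 ranks, FSDP2 (fully_shard) only.
# Each rank contributes its DTensor shards; DCP writes per-rank data files plus .metadata.
state_dict = {"model": get_model_state_dict(model)}
dcp.save(state_dict, checkpoint_id=CKPT_DIR)

# Run B: e.g. 64 ranks, TP2 x FSDP (DP32).
# Build the model with the *new* parallelism first, then load into it. DCP
# reshards the saved DTensors onto the new placements, which is why every
# rank needs to be able to read every file under CKPT_DIR.
state_dict = {"model": get_model_state_dict(model)}
dcp.load(state_dict, checkpoint_id=CKPT_DIR)
set_model_state_dict(model, state_dict["model"])
```

Optimizer states work the same way via get_optimizer_state_dict / set_optimizer_state_dict.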

fegin avatar Feb 20 '25 19:02 fegin

@fegin @tianyu-l

Thanks for the responses.

I'm trying to understand the structure of the files in a DCP checkpoint folder (metadata + per-rank shards). Is there an easy way to examine the contents of these files -- either for debugging or to perform manual resharding?

These files are not directly readable with torch.load. I imagine this could be done by unpacking what dcp_to_torch_save is doing...

jeromeku avatar Feb 23 '25 16:02 jeromeku

The only thing you can do is torch.load the .metadata file. The actual data files cannot be unpickled without writing some code.
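Something along these lines should work for poking around (the folder path is a placeholder); note that dcp_to_torch_save gathers everything onto one host, so it needs enough memory to hold the full state dict:

```python
import torch
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

CKPT_DIR = "outputs/checkpoint/step-1000"  # placeholder path

# The .metadata file holds a pickled Metadata object: it maps each fully
# qualified tensor name to its global shape/dtype and the chunks (offsets and
# sizes) stored in the per-rank data files. read_metadata() unpickles it for you.
metadata = FileSystemReader(CKPT_DIR).read_metadata()
print(list(metadata.state_dict_metadata.keys()))

# To look at the actual tensor data, consolidate the sharded checkpoint into a
# single torch.save file and inspect it with plain torch.load.
dcp_to_torch_save(CKPT_DIR, "consolidated.pt")
full_sd = torch.load("consolidated.pt", weights_only=False)
```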

fegin avatar Feb 25 '25 07:02 fegin