FSDP hybrid shard should checkpoint in a single node
Description & Motivation
https://github.com/pytorch/pytorch/pull/104810 adds the recommendation that the save APIs should only be called by the ranks of a single node (the shard group).
https://github.com/pytorch/pytorch/issues/102904#issuecomment-1862892480 also discusses this.
Our logic doesn't do this and instead calls the save APIs on all ranks.
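For illustration, a minimal sketch of what the PR's recommendation amounts to, assuming PyTorch >= 2.2's `torch.distributed.checkpoint.save`; `shard_group` / `replicate_group` stand for the intra-node and inter-node process groups of the hybrid-sharded model and are hypothetical parameters here, not Lightning's current code:

```python
import torch.distributed as dist
import torch.distributed.checkpoint as dcp


def save_on_one_shard_group(sharded_state_dict, path, shard_group, replicate_group):
    # With HYBRID_SHARD the model is sharded within a node and replicated across
    # nodes, so a single replica (here: the one whose rank in the replication
    # group is 0) holds everything needed to write the checkpoint.
    if dist.get_rank(replicate_group) == 0:
        dcp.save(sharded_state_dict, checkpoint_id=path, process_group=shard_group)
    # Keep the non-writing ranks in step before training continues.
    dist.barrier()
```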
Additional context
Lit-GPT uses hybrid sharding in pretrain/tinyllama.py, but with full checkpointing. I believe this feature request is only relevant for sharded checkpointing (see the sketch after this comment). @awaelchli Did you try it? Does sharded hybrid checkpointing work?
cc @borda @awaelchli @carmocca
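For reference, a hedged sketch of the configuration in question; the `FSDPStrategy` arguments below are an assumption about how hybrid sharding plus sharded checkpointing would be enabled in Lightning, not a confirmed working setup:

```python
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

strategy = FSDPStrategy(
    sharding_strategy="HYBRID_SHARD",  # shard within a node, replicate across nodes
    state_dict_type="sharded",         # write one checkpoint shard per rank
)
trainer = L.Trainer(accelerator="gpu", devices=8, num_nodes=4, strategy=strategy)
```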
Yes, we should fix this. I don't think I have ever had the chance to test it. While we could fix this temporarily by simply calling the saver only on node 0, in general a hybrid shard group can span an arbitrary number of processes. We would need to get the right process group from the device mesh, pass it to the saver as well, and only call the saver from those ranks (a sketch follows below).
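A hedged sketch of what that could look like; the mesh shape and dimension names ("replicate", "shard") are assumptions about how the 2D mesh is built, not an existing Lightning API:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# e.g. launched with torchrun across 4 nodes x 8 GPUs per node
mesh = init_device_mesh("cuda", (4, 8), mesh_dim_names=("replicate", "shard"))

replicate_group = mesh.get_group("replicate")  # inter-node replication group
shard_group = mesh.get_group("shard")          # intra-node sharding group

# Only ranks whose replica index is 0 should call the saver, and the saver
# should coordinate over `shard_group` (see the sketch in the description).
should_save = dist.get_rank(replicate_group) == 0
```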
This feature would be very useful. Is it supported now?
Also interested in this along with async checkpointing.