
FSDP hybrid shard should checkpoint in a single node

Open · carmocca opened this issue 1 year ago · 4 comments

Description & Motivation

https://github.com/pytorch/pytorch/pull/104810 adds the recommendation that the save APIs should be called on a single node (shard group).

https://github.com/pytorch/pytorch/issues/102904#issuecomment-1862892480 also discusses this.

Our logic doesn't follow this recommendation: it runs the save code on all ranks.
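A minimal sketch of the recommendation, assuming the default hybrid-shard layout where each node forms one shard group; the `LOCAL_WORLD_SIZE` fallback and the placement of the save call are illustrative, not Lightning's implementation:

```python
# Sketch only: with HYBRID_SHARD every node holds a full replica of the sharded
# state, so the sharded save APIs only need to run on the ranks of a single node.
import os

import torch
import torch.distributed as dist

# Assumes one shard group per node (the default HYBRID_SHARD layout).
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", torch.cuda.device_count()))
node_index = dist.get_rank() // local_world_size

if node_index == 0:
    # Only node 0 calls the distributed checkpoint save APIs.
    ...
```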

Additional context

Lit-gpt uses hybrid sharding in pretrain/tinyllama.py but full checkpointing. I believe this feature request is only relevant for sharded checkpointing. @awaelchli Did you try it? Does sharded hybrid checkpointing work?

cc @borda @awaelchli @carmocca

carmocca — Feb 17 '24 14:02

Yes, we should fix this. I don't think I've ever had the chance to test it. While we could temporarily fix it by calling the saver only on node 0, in general a hybrid shard group can span an arbitrary number of processes. We would need to get the right process group from the device mesh, pass it to the saver, and only call the saver from those ranks.
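A minimal sketch of that approach, assuming the model was wrapped with a 2D DeviceMesh and a recent PyTorch with `torch.distributed.checkpoint.save`; the mesh dimension names `"replicate"`/`"shard"` and the helper name `save_hybrid_shard_checkpoint` are illustrative, not an existing Lightning API:

```python
import torch.distributed.checkpoint as dcp
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def save_hybrid_shard_checkpoint(model: FSDP, mesh: DeviceMesh, path: str) -> None:
    """Save a sharded checkpoint from a single shard group only (sketch)."""
    # `mesh` is the 2D mesh the model was wrapped with, e.g.
    # init_device_mesh("cuda", (num_nodes, gpus_per_node), mesh_dim_names=("replicate", "shard"))
    coords = mesh.get_coordinate()  # this rank's (replicate, shard) coordinate
    if coords is None or coords[0] != 0:
        # Ranks in the other replicas hold identical shards; they skip the save.
        return
    shard_group = mesh["shard"].get_group()  # process group spanning one shard group
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {"model": model.state_dict()}
    # Restrict the saver's collective coordination to the ranks that actually save.
    dcp.save(state_dict, checkpoint_id=path, process_group=shard_group)
```

Passing the shard group as `process_group` should keep the saver's coordination and deduplication within the ranks that actually write, rather than the global group.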

awaelchli — Feb 20 '24 15:02

This feature would be very useful. Is it supported now?

qsh-zh — Mar 17 '24 16:03

Also interested in this along with async checkpointing.

zaptrem — Jul 31 '24 05:07