
Plan to support FSDP2?

Open ByronHsu opened this issue 1 year ago • 11 comments

FSDP2 provides a smaller memory footprint, compatibility with torch.compile, and more flexibility due to per-parameter sharding. Does Hugging Face have plans to support FSDP2?

https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
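
The per-parameter sharding mentioned above (as opposed to FSDP1's single flattened FlatParameter) can be illustrated with a toy, torch-free sketch. The function name and the dim-0 zero-padding scheme here are illustrative assumptions, not the actual FSDP2 implementation:

```python
# Toy illustration (plain Python, no torch) of FSDP2-style
# per-parameter sharding: each parameter is split on dim-0 across
# ranks, with the last shard zero-padded so all shards have equal
# length. FSDP1 instead flattens all parameters into one contiguous
# FlatParameter before sharding, which is less flexible.

def shard_param(param: list[float], world_size: int, rank: int) -> list[float]:
    """Return `rank`'s dim-0 shard of `param`, padded so all shards match."""
    shard_len = -(-len(param) // world_size)  # ceil division
    padded = param + [0.0] * (shard_len * world_size - len(param))
    return padded[rank * shard_len : (rank + 1) * shard_len]
```

Because each parameter keeps its own shard, features like per-parameter dtype and selective resharding become straightforward, which is part of what enables the smaller memory footprint.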

ByronHsu avatar Jun 19 '24 22:06 ByronHsu

Thanks for bringing FSDP2 to our (or at least my) attention. The changes described in the document you linked sound very reasonable and could remove some of the common pain points of using FSDP.

Reading this, I got the impression that this is a very new addition to PyTorch. Searching for fully_shard in the PyTorch docs returns no hits, which reinforces this impression. But looking at the actual code, it's already 2 years old! So I'm confused now about the state of this feature: is it going to be officially released soon, or is it more of an experimental feature that may or may not see continued work? Do you have any insights on that @ByronHsu?

BenjaminBossan avatar Jun 20 '24 08:06 BenjaminBossan

Thanks @BenjaminBossan! If I understand correctly, the PyTorch team wants to replace FSDP1 with FSDP2 in the long term. I saw it has already been integrated into torchtitan. Maybe we can have some plans for accelerate too? Otherwise, users cannot use torch.compile with FSDP in the Hugging Face ecosystem. cc PyTorch team @awgu @msaroufim

ByronHsu avatar Jun 20 '24 23:06 ByronHsu

But looking at the actual code, it's already 2 years old!

Very sorry for the confusion! There are two separate functions called fully_shard, one being 2 years old and one being new from this year. For historical context, we were experimenting with approaches to implementing FSDP that were not an nn.Module wrapper like FullyShardedDataParallel. This led to the distributed/_composable folder, and the APIs were all verbs, hence fully_shard. The original fully_shard called into the same underlying code as FullyShardedDataParallel. The new fully_shard (FSDP2) is a standalone implementation.

We proposed FSDP2 as a prototype for the 2.4 release, and we are investing in it heavily.

awgu avatar Jun 21 '24 00:06 awgu

Thanks a lot for clarifying my confusion. In that case, I think it makes sense to wait until FSDP2 is released and then run experiments with accelerate to see how it can be best supported.

BenjaminBossan avatar Jun 21 '24 08:06 BenjaminBossan

The main worry with FSDP2 is whether it's stable enough to make sense to include in Accelerate. In the worst case, we can keep a draft PR open and/or ship it as an experimental feature (and advertise it as such).

So my main question is:

  • How stable is it already? What ETA is there for it to be considered "stable"?

I planned on looking into FSDP2 in the near future anyway, so I'm open to having some early-ish support for it in Accelerate, as long as I can get a full grasp of how far along in development it is.

(We did something similar with PiPPy, so we're okay doing so here too.)

I know we need to do some heavy uprooting to add custom process group support to Accelerate, which I believe FSDP2 relies on, if I'm not mistaken?

muellerzr avatar Jul 01 '24 19:07 muellerzr

What would be helpful on my end is some bare-bones FSDP2 examples in PyTorch showing how things operate end-to-end.

muellerzr avatar Jul 01 '24 19:07 muellerzr

A bare-bones example of FSDP2 is available in https://github.com/pytorch/torchtitan.

raghukiran1224 avatar Jul 29 '24 16:07 raghukiran1224

Thanks @raghukiran1224 :) Yes indeed, I plan on looking into these with some of the torch folks. Getting something small going is in our near future. (Probably highly experimental, since they're still not settled on things yet.)

muellerzr avatar Jul 29 '24 16:07 muellerzr

We have been looking at this and will be happy to help bring FSDP2 into accelerate as an experimental parallelism option. RFC PR: https://github.com/huggingface/accelerate/pull/3231

cc: @raghukiran1224 @ashokponkumar @prjayach @awgu

kmehant avatar Nov 08 '24 19:11 kmehant

Looking forward to more generic N-D parallel (device mesh, TP, CP) support instead of FSDP2 only. I have implemented a simple AcceleratorNd by inheritance to support this, but I found that we need to change a lot of internal code to cover more cases:

  1. accelerator.gather and other distributed ops must be performed in the data-parallel group instead of the default group
  2. the dataloader/scheduler must use the data-parallel group rank instead of the global rank
  3. save_pretrained doesn't support DTensor (using PyTorch DCP in accelerate and transformers could resolve this)
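
Points 1 and 2 above come down to mapping a global rank onto its coordinates in the device mesh. A minimal sketch of that mapping for a 2-D mesh, assuming a row-major ["dp", "tp"] layout (the function name and layout convention are illustrative, not an actual accelerate or torch API):

```python
# Hypothetical sketch: mapping a global rank to its coordinates on a
# 2-D (dp, tp) device mesh laid out row-major, i.e.
# global_rank = dp_rank * tp_size + tp_rank. This is why gather and
# dataloader sharding must key on the data-parallel group rank
# rather than the global rank: ranks in the same dp row hold the
# same data shard.

def dp_coords(global_rank: int, tp_size: int) -> tuple[int, int]:
    """Return (dp_rank, tp_rank) for a row-major ["dp", "tp"] mesh."""
    return global_rank // tp_size, global_rank % tp_size

# With 8 GPUs arranged as dp=4, tp=2: global ranks 0 and 1 share
# dp_rank 0, so a dataloader keyed on the global rank would wrongly
# feed them different batches.
```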

FindDefinition avatar Dec 17 '24 11:12 FindDefinition

This can be closed, as https://github.com/huggingface/accelerate/pull/3394 supports it. @ByronHsu

cyr0930 avatar Apr 25 '25 03:04 cyr0930