
FSDP+PP bug where reshard_after_forward must be true

Open wconstab opened this issue 9 months ago • 6 comments

https://github.com/pytorch/torchtitan/pull/161/files#diff-80b04fce2b861d9470c6160853441793678ca13904dae2a9b8b7145f29cd017aR269


IIRC @awgu mentioned there was an issue requiring this setting for the time being. Not sure why or if it has been fixed yet?

wconstab avatar May 02 '24 23:05 wconstab

This seems like an important / high(er) priority issue since FSDP + PP generally wants reshard_after_forward=False.

awgu avatar May 02 '24 23:05 awgu

I believe that in the old FSDP, where the FSDP API is called on the whole model, reshard_after_forward can be figured out automatically (or at least there is a way to do so).

I don't know whether the new FSDP still allows the API to be called on the whole model. If it does, can this be investigated so that the burden is not on the user? After all, reshard_after_forward is sort of an internal thing that requires the user to understand a "corner" procedure of FSDP.

That said, following @awgu's comment, should we just do:

if pp:
    reshard_after_forward = False
else:
    reshard_after_forward = <a condition>

kwen2501 avatar May 03 '24 22:05 kwen2501

reshard_after_forward=True == ShardingStrategy.FULL_SHARD == ZeRO-3
reshard_after_forward=False == ShardingStrategy.SHARD_GRAD_OP == ZeRO-2

It is still the same: it cannot be figured out automatically. Only the root module automatically changes to reshard_after_forward=False, since it will all-gather immediately to begin backward anyway. I would not consider this to be a "corner" procedure of FSDP though. This is an important choice that affects the algorithm used, so users are generally aware of it.

awgu avatar May 03 '24 22:05 awgu

By "corner" case, I refer to this line:

reshard_after_forward = layer_id < len(model.layers) - 1

Rather than actively choosing ZeRO-2 or ZeRO-3, I think the user is really saying: I want to use FSDP, but I also want slightly more perf, and since the last layer's backward starts immediately after its forward, please don't reshard it.
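A minimal sketch of that per-layer choice, combined with the `if pp:` override proposed earlier in this thread. The helper name `reshard_flags` and its `pipeline_parallel` parameter are hypothetical, made up for illustration; they are not part of FSDP or torchtitan.

```python
from typing import List


def reshard_flags(num_layers: int, pipeline_parallel: bool = False) -> List[bool]:
    """Return one reshard_after_forward value per layer (hypothetical helper)."""
    if pipeline_parallel:
        # Assumption taken from this thread: FSDP + PP generally wants
        # reshard_after_forward=False for every layer.
        return [False] * num_layers
    # Reshard every layer except the last one, whose backward starts
    # immediately after its forward.
    return [layer_id < num_layers - 1 for layer_id in range(num_layers)]


print(reshard_flags(4))
print(reshard_flags(4, pipeline_parallel=True))
```

Each returned value would then be passed as `reshard_after_forward` when wrapping the corresponding layer.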

kwen2501 avatar May 03 '24 22:05 kwen2501

Put differently: if we already know that ZeRO-3 + ZeRO-3 + ... + ZeRO-2 is going to be a common pattern, can we package that as an offering to our users? Should that be considered the preferred implementation of ZeRO-3 (whole model)?

kwen2501 avatar May 03 '24 22:05 kwen2501

I see. I think since we do not know the execution order in general, we cannot do it easily in the FSDP API itself, which is a building block. Maybe a higher level API that knows how to call FSDP for some class of models could do it.
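As a rough sketch of what such a higher-level API might look like: the caller supplies the layers in execution order (the information FSDP itself does not have), and the wrapper picks the flag for each one. Everything here is an assumption for illustration. `shard_in_execution_order` is a hypothetical name, and `fake_shard` is a stand-in for the real sharding call so the example runs without a distributed setup.

```python
from typing import Callable, List, Sequence, Tuple


def shard_in_execution_order(
    layers: Sequence[object],
    shard_fn: Callable[..., object],
) -> None:
    """Apply shard_fn to each layer, given in forward execution order.

    Hypothetical wrapper: the last-executed layer is not resharded after
    forward because its backward runs right away.
    """
    last = len(layers) - 1
    for position, layer in enumerate(layers):
        shard_fn(layer, reshard_after_forward=position < last)


# Stand-in for the real FSDP call; it only records what it was asked to do.
calls: List[Tuple[object, bool]] = []


def fake_shard(module: object, reshard_after_forward: bool) -> None:
    calls.append((module, reshard_after_forward))


shard_in_execution_order(["embed", "block0", "block1", "head"], fake_shard)
for module, flag in calls:
    print(module, flag)
```

The design choice is that the wrapper, not FSDP, owns the ordering assumption, so it only fits model classes whose execution order is known up front.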

awgu avatar May 04 '24 10:05 awgu