PiPPy
FSDP+PP bug where reshard_after_forward must be true
https://github.com/pytorch/torchtitan/pull/161/files#diff-80b04fce2b861d9470c6160853441793678ca13904dae2a9b8b7145f29cd017aR269
IIRC @awgu mentioned there was an issue requiring this setting for the time being. I'm not sure why, or whether it has been fixed yet.
This seems like an important / high(er)-priority issue, since FSDP + PP generally wants reshard_after_forward=False.
I believe that in old FSDP, where the FSDP API is called on the whole model, reshard_after_forward can be figured out automatically (or at least there is a way to do so).
I don't know whether the new FSDP still allows the API to be called on the whole model. If it does, can this be investigated so that the burden is not on the user? After all, reshard_after_forward is something of an internal detail that requires the user to understand a "corner" procedure of FSDP.
That said, following @awgu's comment, should we just do:
```
if pp:
    reshard_after_forward = False
else:
    reshard_after_forward = <a condition>
```
reshard_after_forward=True == ShardingStrategy.FULL_SHARD == ZeRO-3
reshard_after_forward=False == ShardingStrategy.SHARD_GRAD_OP == ZeRO-2
It is still the same: it cannot be figured out automatically. Only the root module automatically switches to reshard_after_forward=False, since it will all-gather immediately to begin backward anyway. I would not consider this a "corner" procedure of FSDP, though. It is an important choice that determines the algorithm used, so users are generally aware of it.
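The equivalence between the new per-module flag and the old FSDP sharding strategies can be written down as a tiny lookup. This is a minimal sketch: the `ShardingStrategy` enum here is a local stand-in that only mirrors the two member names of PyTorch's `torch.distributed.fsdp.ShardingStrategy` so the snippet runs without PyTorch, and `equivalent_strategy` is a hypothetical helper, not a PyTorch API.

```python
from enum import Enum


# Local stand-in (hypothetical) mirroring two member names of
# torch.distributed.fsdp.ShardingStrategy, so this runs without PyTorch.
class ShardingStrategy(Enum):
    FULL_SHARD = "FULL_SHARD"        # ZeRO-3
    SHARD_GRAD_OP = "SHARD_GRAD_OP"  # ZeRO-2


def equivalent_strategy(reshard_after_forward: bool) -> tuple[ShardingStrategy, str]:
    """Map the reshard_after_forward flag to the old FSDP strategy and ZeRO stage."""
    if reshard_after_forward:
        # Parameters are freed after forward and all-gathered again for backward.
        return ShardingStrategy.FULL_SHARD, "ZeRO-3"
    # Parameters stay unsharded after forward until backward has used them.
    return ShardingStrategy.SHARD_GRAD_OP, "ZeRO-2"


print(equivalent_strategy(True))
print(equivalent_strategy(False))
```

The point of the mapping is that choosing the flag is choosing the algorithm, which is why it is exposed to the user rather than inferred.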
By "corner" case, I am referring to this line:
```python
reshard_after_forward = layer_id < len(model.layers) - 1
```
Compared to actively choosing ZeRO-2 or ZeRO-3, I think the user is really saying: "I want to use FSDP, but I also want slightly better performance. The last layer's backward will start immediately after its forward, so please don't reshard it."
Put differently, if we already know that ZeRO-3 + ZeRO-3 + ... + ZeRO-2 is going to be a common pattern, can we package that as an offering to our users? Should that be considered the preferred implementation of ZeRO-3 (whole model)?
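The ZeRO-3 + ... + ZeRO-2 pattern produced by the torchtitan line can be illustrated in plain Python. This is a sketch under a deliberately simplified cost model (an assumption, not FSDP's actual accounting): each block all-gathers its parameters once in forward, and once more in backward only if it resharded after forward. The function names are hypothetical.

```python
def per_layer_reshard_flags(num_layers: int) -> list[bool]:
    # Same condition as the torchtitan line quoted in this thread: every
    # block reshards after forward except the last one.
    return [layer_id < num_layers - 1 for layer_id in range(num_layers)]


def toy_all_gather_count(flags: list[bool]) -> int:
    # Simplified model (assumption): 1 all-gather in forward per block,
    # plus 1 in backward if the block resharded after forward.
    return sum(2 if reshard else 1 for reshard in flags)


flags = per_layer_reshard_flags(4)
print(flags)                             # [True, True, True, False]
print(toy_all_gather_count([True] * 4))  # 8 (ZeRO-3 for every block)
print(toy_all_gather_count(flags))       # 7 (last block kept unsharded)
```

Under this model the saving is exactly one all-gather per step, paid for by keeping the last block's unsharded parameters in memory between its forward and backward.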
I see. Since we do not know the execution order in general, I think we cannot easily do this in the FSDP API itself, which is a building block. Maybe a higher-level API that knows how to call FSDP for some class of models could do it.
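One way such a higher-level API might look is sketched below. `reshard_policy` is entirely hypothetical (not part of FSDP or torchtitan); it assumes the caller supplies the forward execution order of the blocks and that block names are unique. It applies the PP override proposed earlier in the thread and otherwise keeps only the last block to run forward unsharded.

```python
from typing import Dict, Sequence


def reshard_policy(forward_order: Sequence[str], pipeline_parallel: bool) -> Dict[str, bool]:
    """Hypothetical higher-level helper (not an FSDP API).

    Assumes the caller knows the forward execution order of the blocks it
    will wrap, and that block names are unique.
    """
    if pipeline_parallel:
        # Proposal from this thread: with PP, never reshard after forward.
        return {name: False for name in forward_order}
    # ZeRO-3 for every block except the last one to run forward, whose
    # backward starts right after its forward (ZeRO-2 for that block only).
    last_index = len(forward_order) - 1
    return {name: i < last_index for i, name in enumerate(forward_order)}


blocks = ["layers.0", "layers.1", "layers.2"]
print(reshard_policy(blocks, pipeline_parallel=False))
# {'layers.0': True, 'layers.1': True, 'layers.2': False}
print(reshard_policy(blocks, pipeline_parallel=True))
# {'layers.0': False, 'layers.1': False, 'layers.2': False}
```

The execution order is an input rather than something the helper discovers, which reflects the point above: the knowledge has to come from a layer that understands the model, not from FSDP itself.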