Andrew Gu
cc: @Chillee would be curious to hear your thoughts
We do not have a unit test that can capture the difference between fp32/bf16 vs. fp16 division factors yet. It might be simpler to test this when we do implement...
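As a rough standalone illustration (not the eventual unit test, and independent of any FSDP internals), the sketch below shows why the fp16 factorization matters: summing many fp16 gradients and dividing once by the world size can overflow, whereas pre-dividing by roughly `sqrt(world_size)` before the reduction and post-dividing by the same factor afterwards keeps the intermediate values in range. The values, shapes, and world size here are made up for the example.

```python
import torch

world_size = 64
# Hypothetical per-rank fp16 gradient values chosen so that a plain sum
# across ranks exceeds the fp16 max (~65504) and overflows to inf.
local_grads = [torch.full((4,), 2000.0, dtype=torch.float16) for _ in range(world_size)]

# Post-divide only (fp32/bf16-style factorization): sum first, divide once.
post_only = sum(local_grads) / world_size  # intermediate sum is 128000 -> inf in fp16

# Pre- and post-divide (fp16-style factorization): divide by sqrt(world_size)
# on each rank before the reduction, then by sqrt(world_size) again after it.
factor = world_size ** 0.5
pre_post = sum(g / factor for g in local_grads) / factor

print(post_only)  # inf in fp16 because the un-divided sum overflows
print(pre_post)   # ~2000.0, the correct average
```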
> > * Upon the second microbatch, rank 0 after its reduce-scatter will additionally have its shard of $\frac{1}{8}\sum_{i \in S(0)} g_i^{(2)}$. If we only all-reduce this, then this second microbatch's gradients become...
> > Sorry, the point of what we are trying to do is to _not_ all-reduce the first microbatch's gradients. This is to save communication. Just reduce-scattering is enough to...
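For concreteness, here is one way to write out the accumulation the exchange above refers to, under the reading that $g_i^{(k)}$ is rank $i$'s local gradient for microbatch $k$, $S(0)$ is the set of ranks whose reduce-scatter output lands on rank 0, and $\tfrac{1}{8}$ is the division factor for an 8-rank reduce-scatter group (this reading of the notation is my assumption, not stated explicitly above). Because each microbatch's reduce-scatter output is added into the same sharded gradient buffer, rank 0's shard goes from

$$\frac{1}{8}\sum_{i \in S(0)} g_i^{(1)} \quad\longrightarrow\quad \frac{1}{8}\sum_{i \in S(0)} \left(g_i^{(1)} + g_i^{(2)}\right)$$

after the second microbatch, so any later collective that operates on the accumulated shard sees both microbatches' contributions at once.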
@pytorchbot merge
This seems like an important / high(er) priority issue since FSDP + PP generally wants `reshard_after_forward=False`.
- `reshard_after_forward=True` == `ShardingStrategy.FULL_SHARD` == ZeRO-3
- `reshard_after_forward=False` == `ShardingStrategy.SHARD_GRAD_OP` == ZeRO-2

It is still the same (cannot be automatically figured out -- only the root module auto changes to `reshard_after_forward=False` since...
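For anyone mapping between the two APIs, a rough sketch of how each spelling looks (the import path for `fully_shard` has moved across releases, and the model/wrapping here is purely illustrative; run under `torchrun` with one GPU per rank):

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
# In newer releases fully_shard is exported from torch.distributed.fsdp;
# in older ones it lived under torch.distributed._composable.fsdp.
from torch.distributed.fsdp import fully_shard

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

def make_model():
    return nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 1024)).cuda()

# FSDP1: ZeRO-2-style behavior (parameters stay unsharded after forward).
zero2_fsdp1 = FSDP(make_model(), sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)

# FSDP2: the same behavior expressed per module.
zero2_fsdp2 = make_model()
fully_shard(zero2_fsdp2, reshard_after_forward=False)
```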
I see. I think since we do not know the execution order in general, we cannot do it easily in the FSDP API itself, which is a building block. Maybe...
The advantage of meta-device init is that it is as fast as possible: the _sharded_ parameters are directly initialized on _GPU_. Any other flow requires something more, e.g. (1) initializing...
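A minimal sketch of that meta-device flow with the per-module API, assuming an initialized process group and one CUDA device per rank; `TinyModel` and its `init_weights` are stand-ins I made up to keep the example self-contained, and the `fully_shard` import path may differ by release:

```python
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard  # torch.distributed._composable.fsdp in older releases

class TinyModel(nn.Module):
    """Stand-in for a real model; only here to make the sketch self-contained."""
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(1024, 1024) for _ in range(4))

    def init_weights(self):
        for layer in self.layers:
            nn.init.trunc_normal_(layer.weight, std=0.02)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# (1) Construct on meta: no real memory is allocated, so construction is instant.
with torch.device("meta"):
    model = TinyModel()

# (2) Apply FSDP per block and then to the root; parameters become sharded (still on meta).
for layer in model.layers:
    fully_shard(layer)
fully_shard(model)

# (3) Materialize only this rank's shard directly on GPU, then run the real init on it.
model.to_empty(device="cuda")
model.init_weights()
```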
@XinDongol A few clarifications:
```
model = model_cls.from_model_args(model_config)
# (1) If the `model_cls.__init__` did not already call `init_weights()` or similar
model.init_weights()
# (2) Apply FSDP with multiple FSDP calls, e.g....
```
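Purely as an illustration of what step (2) could look like (the per-block loop, `model.layers`, and the import path are my guesses for this sketch, not the actual continuation of the snippet above):

```python
from torch.distributed.fsdp import fully_shard  # torch.distributed._composable.fsdp in older releases

# (2) continued (hypothetical): one fully_shard call per transformer block,
# then one call on the root module so the remaining parameters are also sharded.
for block in model.layers:
    fully_shard(block)
fully_shard(model)
```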