Why not use set_requires_all_reduce for HSDP?
It seems that the HSDP (using dp_replicate) gradient accumulation logic is not optimal.
The global all-reduce is triggered on every gradient accumulation step, whereas it should only be triggered on the last one. Adding model.set_requires_all_reduce(micro_step == len(micro_batches) - 1) (https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.set_requires_all_reduce) would trigger the all-reduce only on the last micro-step.
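For concreteness, here is a minimal sketch of what this could look like in a gradient accumulation loop. The names micro_batches, loss_fn, and optimizer are placeholders (not torchtitan's actual training loop), and model is assumed to be the root module wrapped with fully_shard on a 2D replicate/shard mesh:

```python
# Hypothetical gradient accumulation loop with HSDP (FSDP2 fully_shard on a
# 2D replicate/shard mesh). `micro_batches`, `loss_fn`, `optimizer` are placeholders.
for micro_step, (inputs, labels) in enumerate(micro_batches):
    is_last_micro_step = micro_step == len(micro_batches) - 1
    # Run the cross-replica all-reduce only on the last micro-step; the
    # reduce-scatter within each shard group still runs every micro-step,
    # so gradients stay sharded in between.
    model.set_requires_all_reduce(is_last_micro_step)
    loss = loss_fn(model(inputs), labels)
    loss.backward()

optimizer.step()
optimizer.zero_grad()
```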
My guess is that in practice nobody does more than one gradient accumulation step when using dp_replicate, but it would still be nice to add to be feature complete, imo.
oh, good point.
For pure FSDP, there's a tradeoff between performing the reduce-scatter only once and performing it on every gradient accumulation step. But for HSDP, it seems there's no benefit to performing the all-reduce every step.
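For comparison, a minimal sketch of the pure-FSDP side of that tradeoff (same placeholder names as above): set_requires_gradient_sync delays the reduce-scatter to the last micro-step, saving communication but keeping unsharded gradients around between micro-steps.

```python
# Hypothetical pure-FSDP gradient accumulation loop: skip the reduce-scatter on
# all but the last micro-step. This trades communication for memory, since
# gradients are accumulated unsharded until the final micro-step.
for micro_step, (inputs, labels) in enumerate(micro_batches):
    is_last_micro_step = micro_step == len(micro_batches) - 1
    model.set_requires_gradient_sync(is_last_micro_step)
    loss = loss_fn(model(inputs), labels)
    loss.backward()
```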
But like you said:
> in practice nobody does more than one gradient accumulation step when using dp_replicate
The point of grad accumulation is to save memory so training fits (with a smaller batch per step and more steps, hence slower); the point of HSDP is to sacrifice memory for throughput. They are quite the opposite?
If we could argue that no one should ever do grad accumulation + HSDP, then this optimization might be addressing a non-problem.
> The point of grad accumulation is to save memory so training fits (with a smaller batch per step and more steps, hence slower); the point of HSDP is to sacrifice memory for throughput. They are quite the opposite?
I think there is a bit more to this story. HSDP with a large batch size can also help on clusters with bad interconnect across nodes. Also, people might be using torchtitan to run research ablations around batch size, in which case they might want to artificially increase the batch size. And when doing RL (if someone uses torchtitan as a backend), you have a very large batch size.
My take is that it's worth adding, since it's a one-line change that does nothing in the most common setup (a single gradient accumulation step) but helps in the others. Also, I believe people look to torchtitan as the best reference for how to use FSDP2, so it's definitely nice to showcase all the features :)
I am happy to draft a PR if this sounds good to you.
> HSDP with a large batch size can also help on clusters with bad interconnect across nodes.
This makes sense. What's related is that HSDP can be used in torchft fault-tolerant training.
> Also, people might be using torchtitan to run research ablations around batch size, in which case they might want to artificially increase the batch size. And when doing RL (if someone uses torchtitan as a backend), you have a very large batch size.
For these two, why can't they just use FSDP?
> I am happy to draft a PR if this sounds good to you.
Sure, please. Thank you!
By the way, unlike set_requires_gradient_sync, set_requires_all_reduce does not incur an additional memory burden: the reduce-scatter within each shard group still runs every micro-step, so gradients are accumulated in sharded form rather than kept unsharded.
@tianyu-l