
Why not use set_requires_all_reduce for HSDP?

Open samsja opened this issue 2 months ago • 4 comments

It seems that the HSDP (dp_replicate) + gradient accumulation logic is not optimal.

It seems that the global all-reduce is triggered on every gradient accumulation step, whereas it should only be triggered on the last one. Adding model.set_requires_all_reduce(micro_step == len(micro_batches) - 1) (https://docs.pytorch.org/docs/stable/distributed.fsdp.fully_shard.html#torch.distributed.fsdp.FSDPModule.set_requires_all_reduce) would trigger the all-reduce only on the last step.
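A minimal sketch of what that could look like (micro_batches, loss_fn, and optimizer are placeholders for the training-loop objects, not torchtitan code; model is assumed to be wrapped with fully_shard on a 2D HSDP mesh):

```python
for micro_step, micro_batch in enumerate(micro_batches):
    is_last_micro_step = micro_step == len(micro_batches) - 1
    # Defer the cross-replica (dp_replicate) all-reduce to the last micro-step;
    # the intra-shard reduce-scatter still runs on every backward.
    model.set_requires_all_reduce(is_last_micro_step)
    # (loss scaling by the number of micro-batches omitted for brevity)
    loss = loss_fn(model(micro_batch))
    loss.backward()

optimizer.step()
optimizer.zero_grad()
```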

My guess is that in practice nobody is doing more than one gradient accumulation step when using dp_replicate, but it would still be nice to add to be feature complete, imo.

samsja avatar Oct 22 '25 15:10 samsja

oh, good point.

For pure FSDP, there's a tradeoff in whether to perform the reduce-scatter only once or on every gradient accumulation step. But for HSDP, there seems to be no benefit to performing the all-reduce every step.
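For contrast, a sketch of the pure-FSDP side of that tradeoff using set_requires_gradient_sync (same placeholder names as the sketch above; this is my reading, not torchtitan's actual loop):

```python
for micro_step, micro_batch in enumerate(micro_batches):
    is_last_micro_step = micro_step == len(micro_batches) - 1
    # False -> skip gradient communication and accumulate unsharded gradients
    # locally: fewer reduce-scatters, but higher gradient memory until the
    # final micro-step.
    model.set_requires_gradient_sync(is_last_micro_step)
    loss = loss_fn(model(micro_batch))
    loss.backward()
```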

But like you said

in practice nobody is doing more than one grad accumulation when using the dp_replicate

The point of grad accumulation is to save memory so that training fits (with a smaller batch per step and more steps, hence slower); the point of HSDP is to sacrifice memory in exchange for throughput. They are quite the opposite?

If we could argue that no one should ever do grad accumulation + HSDP, then the optimization might be a fake problem.

tianyu-l avatar Oct 23 '25 08:10 tianyu-l

The point of grad accumulation is to save memory so that training fits (with a smaller batch per step and more steps, hence slower); the point of HSDP is to sacrifice memory in exchange for throughput. They are quite the opposite?

I think there is a bit more to this story. HSDP with a large batch size can also help on clusters with poor interconnect across nodes. Also, people might be using torchtitan to run research ablations around batch size, in which case they might want to artificially increase the batch size. And when doing RL (if someone uses torchtitan as the backend), you have very large batch sizes.

My take is that it's worth adding, since it's a one-line change that does nothing in the most common setup (a single grad accumulation step) but helps in the others. Also, I believe people look to torchtitan as the best reference for how to use FSDP2, so it's definitely nice to showcase all the features :)

I am happy to draft a PR if this sounds good to you.

samsja avatar Oct 23 '25 15:10 samsja

HSDP with a large batch size can also help on clusters with poor interconnect across nodes.

This makes sense. What's related is that HSDP can be used in torchft fault-tolerant training.

Also, people might be using torchtitan to run research ablations around batch size, in which case they might want to artificially increase the batch size. And when doing RL (if someone uses torchtitan as the backend), you have very large batch sizes.

For these two, why can't they just use FSDP?

I am happy to draft a PR if this sounds good to you.

Sure, please. Thank you!

tianyu-l avatar Oct 24 '25 18:10 tianyu-l

By the way, unlike set_requires_gradient_sync, set_requires_all_reduce does not incur an additional memory burden. @tianyu-l
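A rough sketch of why, as I understand the fully_shard semantics (model is an HSDP-wrapped FSDPModule):

```python
# With gradient accumulation over N micro-steps:
model.set_requires_gradient_sync(False)  # skips reduce-scatter *and* all-reduce:
                                         # gradients accumulate unsharded, so peak
                                         # gradient memory grows to full size.
model.set_requires_all_reduce(False)     # reduce-scatter still runs each backward:
                                         # gradients stay sharded; only the
                                         # dp_replicate all-reduce is deferred.
```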

EquationWalker avatar Nov 12 '25 07:11 EquationWalker