
Gradient accumulation is not efficiently implemented for distributed recipes

Open • physicsrob opened this issue Aug 06 '24 • 13 comments

For distributed recipes, such as full_finetune_distributed, the gradients end up getting synchronized after each backward() pass instead of only once before the optimizer step. This results in significant unnecessary communication overhead.

I noticed this when porting a training script from Hugging Face Accelerate to one based on the full_finetune_distributed recipe. For my use case (fine-tuning Llama3 on 8 nodes with 4x A100-80GB per node), my training time more than doubled.

Digging in, I found this to be the cause. I've fixed this in my recipe, but it seems like the fix should be applied to all distributed recipes.

I'd be happy to contribute a PR if it's agreed that it should be fixed.

For reference, the fix roughly amounts to wrapping the forward/backward with:

    with nullcontext() if do_sync else self._model.no_sync():

where do_sync is True only on the last step of the gradient accumulation window.

(There's an additional, fairly minor bug where the first optimizer step happens after only a single forward/backward instead of a full accumulation window.)
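
For concreteness, a minimal sketch of how this could look in the inner loop. Here model stands in for the recipe's self._model, and dataloader, loss_fn, optimizer, and gradient_accumulation_steps are assumed to be defined; batch is assumed to be a dict of model inputs with a "labels" entry.

from contextlib import nullcontext

for idx, batch in enumerate(dataloader):
    # Sync (reduce-scatter) gradients only on the last micro-batch of each accumulation window.
    do_sync = (idx + 1) % gradient_accumulation_steps == 0
    # no_sync() skips the gradient collective, so gradients accumulate locally (unsharded).
    with nullcontext() if do_sync else model.no_sync():
        logits = model(**batch)
        loss = loss_fn(logits, batch["labels"]) / gradient_accumulation_steps
        loss.backward()
    if do_sync:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)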

physicsrob avatar Aug 06 '24 20:08 physicsrob

cc @weifengpy

This seems pretty straightforward, thanks for catching this @physicsrob. We'd love to accept a PR to fix this :)

joecummings avatar Aug 06 '24 20:08 joecummings

Looking forward to the PR. I'm actually curious: which line of code triggered the sync?

weifengpy avatar Aug 06 '24 21:08 weifengpy

Also cc: @awgu since we had a chat about gradient accumulation and FSDP in the early days of torchtune

kartikayk avatar Aug 09 '24 21:08 kartikayk

For FSDP, there are two ways to accumulate gradients:

  1. Accumulate unsharded gradients (model.no_sync context in FSDP1, model.set_requires_gradient_sync(is_last_microbatch) in FSDP2)
  2. Accumulate sharded gradients

We should differentiate between training throughput and memory usage (both of which could be referred to by "efficiency").

The two options to accumulate gradients pose a direct tradeoff between communication time and memory usage.

  • Option 1 means no reduce-scatter in backward until the last microbatch backward, saving communication time but using more memory.
  • Option 2 means reduce-scatter in every microbatch backward, incurring more communication time but using less memory.

Note that extra communication time mainly translates to lower throughput when it cannot be fully overlapped, which depends on your inter-node bandwidth.
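
As a rough sketch, option 1 with FSDP2 looks like the following, assuming model has already been wrapped with fully_shard and that microbatches is an iterable of (inputs, labels) pairs with loss_fn and optimizer defined:

num_microbatches = len(microbatches)
for i, (inputs, labels) in enumerate(microbatches):
    is_last_microbatch = i == num_microbatches - 1
    # False -> skip reduce-scatter and accumulate unsharded gradients (option 1);
    # True on the last micro-batch -> one reduce-scatter for the whole window.
    model.set_requires_gradient_sync(is_last_microbatch)
    loss = loss_fn(model(inputs), labels) / num_microbatches
    loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)

Leaving set_requires_gradient_sync(True) for every micro-batch gives option 2 (sharded accumulation) instead.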

For Llama3-8B, the unsharded gradients take ~8B numel. Whether option 1 vs. 2 makes sense depends on how much memory you can afford to tradeoff and whether you want to accumulate gradients in fp32 or bf16.

  • In fp32, the unsharded gradients will take 32 GB, while in bf16, they will take 16 GB.
  • For pre-training, we find that gradient accumulation in fp32 is required for good convergence. For other workloads, this can perhaps be relaxed (it also depends on how many microbatches you are running -- more microbatches generally means fp32 is preferred).

Note that with FSDP2, you can play with this at finer granularity. For example, if you are "wrapping" / applying fully_shard to every transformer block, then you can selectively disable reduce-scatter for every other (or every k-th) transformer block, since set_requires_gradient_sync(bool) is a method on FSDPModule (which each transformer block becomes once you call fully_shard on it). This can help overlap the fewer reduce-scatters, as in the sketch below.
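
A hypothetical sketch of that per-block variant, reusing the names from the sketch above and assuming model.layers holds the transformer blocks, each wrapped with fully_shard:

k = 2  # keep reduce-scatter enabled on every k-th block during non-final micro-batches
for step, (inputs, labels) in enumerate(microbatches):
    is_last_microbatch = step == len(microbatches) - 1
    # Set the flag everywhere first (recurse=True by default)...
    model.set_requires_gradient_sync(is_last_microbatch)
    # ...then re-enable the collective on a subset of blocks so the fewer
    # reduce-scatters can overlap with backward compute.
    for i, block in enumerate(model.layers):
        block.set_requires_gradient_sync(is_last_microbatch or i % k == 0, recurse=False)
    loss = loss_fn(model(inputs), labels) / len(microbatches)
    loss.backward()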

awgu avatar Aug 09 '24 21:08 awgu

So in a no_sync context, gradients are accumulated in fp32?

ScottHoang avatar Aug 20 '24 16:08 ScottHoang

I think in FSDP1, they are accumulated in the dtype that the gradients were computed in (e.g. bf16). In FSDP2, if you specify MixedPrecisionPolicy(reduce_dtype=torch.float32), then it will have extra logic to accumulate the gradients in fp32; if you did not specify that higher precision reduce_dtype, then it will similarly accumulate in the dtype that the gradients were computed in.
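
As a sketch of that FSDP2 configuration (the import path has moved between releases; in recent PyTorch it is torch.distributed.fsdp, in older versions torch.distributed._composable.fsdp), assuming the transformer blocks live under model.layers:

import torch
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy

# Compute in bf16, but reduce-scatter and accumulate gradients in fp32.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
for block in model.layers:
    fully_shard(block, mp_policy=mp_policy)
fully_shard(model, mp_policy=mp_policy)

If reduce_dtype is left unset, gradients are reduced and accumulated in param_dtype (bf16 here), matching the behavior described above.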

awgu avatar Aug 20 '24 17:08 awgu

Extending on this (and maybe unrelated to the overall topic): in the current implementation of FSDP1, we shard parameters across nodes in a multi-node scenario (a ZeRO-3 style setup). Is there a way to limit the sharding to be intra-node?

ScottHoang avatar Aug 20 '24 20:08 ScottHoang

You should be able to pass device_mesh as a 2D device mesh to enable HSDP. (You could also pass in a 2-tuple of process groups, but I think the checkpointing support is better for the device mesh path.)

You can create the device mesh via something like:

import torch.distributed
from torch.distributed import init_device_mesh
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

global_world_size = torch.distributed.get_world_size()  # global world size; assumes dist is initialized
intra_node_size = torch.cuda.device_count()  # GPUs per node, e.g. 8
# 2D mesh: the outer dim replicates across nodes, the inner dim shards within each node
device_mesh = init_device_mesh("cuda", (global_world_size // intra_node_size, intra_node_size))  # (replicate, shard)

fsdp_model = FSDP(model, ..., device_mesh=device_mesh)

awgu avatar Aug 20 '24 20:08 awgu

So, in the case of a small model (< 13B) that we want to scale across multiple nodes for higher throughput, is it better to use HSDP with device_mesh, i.e. intra-node FSDP and inter-node DDP?

ScottHoang avatar Aug 20 '24 20:08 ScottHoang

It depends on your inter-node bandwidth. If your inter-node bandwidth is fast, FSDP is probably still better, especially if your model is compute-dense like a transformer.

The overall workflow I would follow is to get a profiler trace, look at whether the communications are overlapping or not, and decide from there. If you can still overlap the FSDP collectives in the multi-node setup, then you probably prefer FSDP, because you save more memory and can possibly reinvest that into a larger batch size or less activation checkpointing.

If you have slow inter-node bandwidth though, then it is possible that there is a really large discrete jump in communication time when going from single-node to multi-node, in which case HSDP can help because you only have some all-reduce across nodes.

awgu avatar Aug 20 '24 20:08 awgu

This perfectly solved my problem! thank you!

ScottHoang avatar Aug 20 '24 20:08 ScottHoang

@awgu Actually, one last question: in hybrid shard mode with multiple nodes, does "sync_module_states" still broadcast rank 0's params to ranks on different nodes?

ScottHoang avatar Aug 21 '24 05:08 ScottHoang

@ScottHoang yes, it will broadcast from global rank 0 to all ranks (including both intra and inter-node process groups): https://github.com/pytorch/pytorch/blob/afaa5fcecb07472a8805902074f4611dc5798f76/torch/distributed/fsdp/_init_utils.py#L632-L635

awgu avatar Aug 21 '24 13:08 awgu