[Feature] Add gradient accumulation

Open. XinDongol opened this issue 1 year ago (20 comments).

Gradient accumulation (micro-steps) would be very useful when we want a large batch size but only have a limited number of GPUs.

XinDongol, May 01 '24

@XinDongol do you mean microbatching or pipeline parallelism?

wanchaol, May 01 '24

@awgu - is there a context manager or similar option in FSDP2 that would support gradient accumulation and thus enable this in torchtitan? I know we talked about this for HSDP, but I'm not sure about generic FSDP2.

lessw2020, May 01 '24

I am guessing this is asking for normal microbatching. There are similar APIs for FSDP2 that can control communication during gradient accumulation.

We replaced the no_sync() context manager with a plain method, module.set_requires_gradient_sync(bool), so that it can simply be placed at the top of the training loop as module.set_requires_gradient_sync(is_last_microbatch). Note, however, that for memory-constrained cases we typically prefer to proceed as normal and reduce-scatter on every microbatch.

awgu, May 01 '24

Thanks for updating. @wanchaol Yes, I am talking about microbatching.

https://github.com/pytorch/torchtitan/blob/58b11693507bc16e7df4618455ebe66e8094f71d/train.py#L291-L294

@awgu Is it sufficient to make the change below? Thanks! From (current):

```python
with loss_parallel_ctx():
    pred = model(input_ids)
    loss = loss_fn(pred, labels)
    loss.backward()
```

to

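```python
for microbatch_idx in range(microbatch):
    batch = next(data_iterator)
    input_ids, labels = batch
    # Skip the gradient reduce-scatter on every microbatch except the last one.
    model.set_requires_gradient_sync(microbatch_idx == microbatch - 1)
    with loss_parallel_ctx():
        pred = model(input_ids)
        # Divide by the number of microbatches so the summed gradients average over them.
        loss = loss_fn(pred, labels) / microbatch
        loss.backward()
```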
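
Dividing each microbatch loss by the number of microbatches is what makes the accumulated gradient match the full-batch gradient. A minimal pure-Python check, assuming a hypothetical toy scalar model y = w*x with a mean-squared-error loss and equal-sized microbatches (the data and helper names are illustrative only):

```python
def grad_mse(w, xs, ys):
    """Gradient w.r.t. w of mean((w*x - y)^2) for a toy scalar linear model."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, num_microbatches):
    """Sum per-microbatch gradients, each scaled by 1/num_microbatches,
    mirroring `loss_fn(pred, labels) / microbatch` followed by backward()."""
    size = len(xs) // num_microbatches  # assumes equal-sized microbatches
    total = 0.0
    for i in range(num_microbatches):
        mb_x = xs[i * size:(i + 1) * size]
        mb_y = ys[i * size:(i + 1) * size]
        total += grad_mse(w, mb_x, mb_y) / num_microbatches
    return total


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 1.5

full = grad_mse(w, xs, ys)                              # one big batch
accum = accumulated_grad(w, xs, ys, num_microbatches=4)  # 4 microbatches of 2
print(abs(full - accum) < 1e-9)  # True
```

Without the division, the accumulated gradient would be num_microbatches times larger, which effectively scales the learning rate.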

XinDongol, May 01 '24

@XinDongol I think that is sufficient.

If you want to avoid reduce-scatter in backward, then what you have is right. Note, however, that this means the gradients are kept unsharded through the backward passes until the last microbatch, which may use too much memory depending on the workload.

If you want to still reduce-scatter in backward, you can simply remove that model.set_requires_gradient_sync line (effectively leaving it as the default of True).

awgu, May 01 '24

Thanks @awgu @XinDongol. Very helpful discussion. If model.set_requires_gradient_sync is always set to True, is that equivalent to normal training without gradient accumulation, as below?

```python
for microbatch_idx in range(microbatch):
    batch = next(data_iterator)
    input_ids, labels = batch
    with loss_parallel_ctx():
        pred = model(input_ids)
        loss = loss_fn(pred, labels) / microbatch
        loss.backward()
```

Is there a way to accumulate the gradient by keeping a running sum, and only call loss.backward() after finishing all the microbatches?

dreasysnail, Jun 22 '24

What is the advantage of doing this? When you run a microbatch forward, the autograd graph associated with it (e.g. saved activations) is kept alive until you run the corresponding microbatch backward. If you run all of your microbatch forwards before any microbatch backward, your memory cost will be similar to running the entire batch in one forward.
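
To make this memory argument concrete, here is a toy accounting sketch. It assumes each microbatch forward keeps one unit of activations alive until its backward frees it; the unit cost and the helper name are illustrative assumptions, not measurements of real autograd memory:

```python
def peak_live_activations(schedule):
    """Return the peak number of microbatch autograd graphs alive at once.

    `schedule` is a sequence of ("F", i) / ("B", i) events: a forward keeps
    microbatch i's activations alive, and the matching backward frees them.
    """
    live = set()
    peak = 0
    for op, mb in schedule:
        if op == "F":
            live.add(mb)
        else:
            assert mb in live, "backward before forward"
            live.remove(mb)
        peak = max(peak, len(live))
    return peak


k = 4  # number of microbatches
# Gradient accumulation: forward then backward for each microbatch in turn.
interleaved = [ev for i in range(k) for ev in (("F", i), ("B", i))]
# All forwards first, then the backwards (e.g. summing losses before backward()).
all_forward_first = [("F", i) for i in range(k)] + [("B", i) for i in range(k)]

print(peak_live_activations(interleaved))        # 1
print(peak_live_activations(all_forward_first))  # 4
```

With k microbatches, the all-forwards-first schedule holds k graphs at its peak, which is roughly the cost of one full-batch forward.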

awgu, Jun 24 '24

Hi, could you elaborate a bit on "reduce-scatter in backward"? Why does setting model.set_requires_gradient_sync to True or False affect it, and what effect does it have on system performance (memory, communication, etc.)?

Thank you very much for your help! @awgu

LeoXinhaoLee, Sep 09 '24

Each data parallel worker will compute unsharded gradients (i.e. gradients with the original shape / no FSDP sharding) for its local batch in the backward pass. During that backward pass, we can then decide to either "synchronize" them across all data parallel workers via reduce-scatter or not. The only constraint is that by the last microbatch backward, we need to have done this once. If we have k microbatches, then we have flexibility for the preceding k-1 microbatches whether we want to incur this reduce-scatter or not.
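
This flexibility follows from reduce-scatter being linear: reduce-scattering after every microbatch and accumulating the shards gives the same result as accumulating unsharded gradients locally and reduce-scattering once. A single-process simulation, where `reduce_scatter` is a hypothetical helper standing in for the real collective and gradients are plain Python lists (the same argument holds if the collective averages instead of sums):

```python
import random


def reduce_scatter(per_worker_grads):
    """Simulated reduce-scatter (sum): element-wise sum across workers, then
    worker r keeps only the r-th contiguous shard of the result."""
    n = len(per_worker_grads)
    total = [sum(vals) for vals in zip(*per_worker_grads)]
    shard = len(total) // n  # assumes the gradient length divides evenly
    return [total[r * shard:(r + 1) * shard] for r in range(n)]


def add(a, b):
    return [x + y for x, y in zip(a, b)]


random.seed(0)
N, k, numel = 4, 3, 8  # workers, microbatches, gradient elements
# grads[m][r] = unsharded gradient computed by worker r on microbatch m
grads = [[[random.random() for _ in range(numel)] for _ in range(N)] for _ in range(k)]

# Strategy A: reduce-scatter in every microbatch backward, accumulate the shards.
sharded_a = [[0.0] * (numel // N) for _ in range(N)]
for m in range(k):
    shards = reduce_scatter(grads[m])
    sharded_a = [add(acc, s) for acc, s in zip(sharded_a, shards)]

# Strategy B: keep unsharded gradients for k-1 microbatches, reduce-scatter once.
local = [[0.0] * numel for _ in range(N)]
for m in range(k):
    local = [add(acc, g) for acc, g in zip(local, grads[m])]
sharded_b = reduce_scatter(local)

for r in range(N):
    assert all(abs(x - y) < 1e-9 for x, y in zip(sharded_a[r], sharded_b[r]))
print("same sharded gradients on every worker")
```

Strategy A issues k collectives but only ever persists shards; strategy B issues one collective but holds full-size gradients in between.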

If you do choose to reduce-scatter, you convert these unsharded gradients into sharded gradients, which take 1/N of the memory, where N is the FSDP world size (i.e. the number of data parallel workers). However, this requires issuing a collective communication (reduce-scatter), which may be exposed (not overlapped with compute), e.g. if the network bandwidth is low. If you choose not to reduce-scatter, you keep the unsharded gradients across those k-1 microbatches, which likely increases peak memory. However, you avoid the collective communication, which may give higher throughput.
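
A back-of-the-envelope comparison of the two strategies, assuming a model with P gradient elements, FSDP world size N, and k microbatches. The helper is hypothetical and the model is deliberately simplified: it counts only the gradient storage that persists between microbatches and the number of reduce-scatter rounds per step:

```python
def grad_accum_cost(num_params, world_size, num_microbatches, sync_every_microbatch):
    """Rough per-worker cost of one optimizer step with FSDP gradient accumulation.

    Returns (gradient elements kept alive between microbatches,
             reduce-scatter rounds per step).
    Simplified model: ignores activations, the transient per-block unsharded
    gradients that exist during every backward, and compute/communication overlap.
    """
    if sync_every_microbatch:
        # Every backward reduce-scatters, so only the 1/N shards persist
        # between microbatches, at the price of k reduce-scatter rounds.
        return num_params // world_size, num_microbatches
    # set_requires_gradient_sync(False) on the first k-1 microbatches: the full
    # unsharded gradients persist, and only the last backward reduce-scatters.
    return num_params, 1


P, N, k = 7_000_000_000, 8, 4  # hypothetical model size, FSDP world size, microbatches
print(grad_accum_cost(P, N, k, sync_every_microbatch=True))   # (875000000, 4)
print(grad_accum_cost(P, N, k, sync_every_microbatch=False))  # (7000000000, 1)
```

Whether the extra N-fold gradient memory is worth saving k-1 collectives depends on how much memory headroom there is and whether the reduce-scatters are already hidden behind compute.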

awgu, Sep 09 '24

Thank you very much for the informative answer!

For the "sync after every micro-batch" strategy (high comm, low peak mem), I'm wondering if it's possible to further reduce memory by reduce-scattering each layer's gradients as soon as they are computed during backward (from deeper to shallower layers), instead of reduce-scattering after the gradients of all layers are computed.

This would still have the same communication cost as reduce-scattering after all layers' gradients are computed, but the peak memory in backward should be further reduced.

LeoXinhaoLee, Sep 09 '24

FSDP is already doing that :) How parameters/gradients are grouped together is determined by how you call the API on modules. The common approach is to group each transformer block together. After one transformer block's gradients have been computed, we reduce-scatter it, overlapping with the next transformer block's gradient computation.
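
A toy simulation of why per-block reduce-scatter lowers peak memory, assuming L transformer blocks that each produce G unsharded gradient elements. The helper is hypothetical, and it models the overlap crudely: a block's reduce-scatter is assumed to finish while the next block's gradients are computed, so at most two blocks' unsharded gradients coexist. Sharded gradient storage, which is the same in both cases, is not counted:

```python
def peak_unsharded_grads(num_blocks, block_numel, per_block_reduce_scatter):
    """Peak unsharded gradient elements alive during one simulated backward.

    Blocks are visited last-to-first. With per-block reduce-scatter, a block's
    unsharded gradient is freed once its reduce-scatter completes, modeled as
    finishing while the next block's gradients are computed. Without it, all
    unsharded gradients stay alive until the end of backward.
    """
    live = 0
    peak = 0
    in_flight = False  # is a previous block's reduce-scatter still running?
    for _ in range(num_blocks):
        live += block_numel          # compute this block's gradients
        peak = max(peak, live)
        if per_block_reduce_scatter:
            if in_flight:
                live -= block_numel  # previous block's reduce-scatter done; free it
            in_flight = True         # launch this block's reduce-scatter
    return peak


L, G = 32, 200_000_000  # hypothetical: 32 blocks of 200M gradient elements each
print(peak_unsharded_grads(L, G, per_block_reduce_scatter=True))   # 400000000
print(peak_unsharded_grads(L, G, per_block_reduce_scatter=False))  # 6400000000
```

Under these assumptions the per-block peak stays at about two blocks regardless of depth, while reduce-scattering at the end grows linearly with the number of blocks.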

awgu, Sep 09 '24

This is very cool. Is there any example (maybe in Torchtitan or some other tutorial) that demonstrates using FSDP with this strategy? Or is the grouping of blocks and when to reduce-scatter figured out by compiler automatically?

LeoXinhaoLee, Sep 09 '24

No compiler is involved in eager mode. You can take a look here: https://github.com/pytorch/torchtitan/blob/1923ce4db3018a69d2463a6efd7e1ae44cb02ec6/torchtitan/parallelisms/parallelize_llama.py#L289

awgu, Sep 09 '24

Will look into it, thank you very much for your help!

LeoXinhaoLee, Sep 09 '24