
Gradient Scaling With Pipeline Parallelism

Open windsornguyen opened this issue 10 months ago • 2 comments

The idiomatic way to perform gradient scaling is something like this:

preds = model(inputs)
loss = loss_fn(preds, targets)
scaler.scale(loss).backward()
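
For context, the full AMP pattern that snippet comes from looks roughly like this (a minimal sketch; the toy model, optimizer, and data below are placeholders, not part of the issue):

import torch

# Toy setup for illustration only; in practice these come from your training script.
model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()
inputs = torch.randn(8, 16, device="cuda")
targets = torch.randint(0, 4, (8,), device="cuda")

scaler = torch.amp.GradScaler("cuda")

with torch.autocast(device_type="cuda", dtype=torch.float16):
    preds = model(inputs)
    loss = loss_fn(preds, targets)

scaler.scale(loss).backward()  # backward runs on the scaled loss
scaler.step(optimizer)         # unscales grads; skips the step if infs/NaNs are found
scaler.update()                # adjusts the scale factor for the next iteration
optimizer.zero_grad()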

Given that the current PyTorch PP API handles the backward pass internally, I find it difficult to do gradient scaling under a PP regime.

if is_first_stage:
    pp_schedule.step(inputs)                        # bwd performed internally
elif is_last_stage:
    losses = []
    pp_schedule.step(target=targets, losses=losses) # bwd performed internally
else:
    pp_schedule.step()                              # bwd performed internally

loss = (
    torch.mean(torch.stack(losses)).to(device)
    if is_last_stage
    else torch.tensor([-1.0], device=device)
)

# scaler.scale(loss).backward() <-- !? backward pass has already been performed

Is there currently a good way to do gradient scaling with Pipeline Parallelism? And if not, will the Pipeline Parallelism API support gradient scaling in the near-term future?

windsornguyen avatar Jan 24 '25 12:01 windsornguyen

scaler.scale(loss).backward()

In TorchTitan, we do gradient scaling with the backward pass performed ahead of it. Could I have more information on the scaler.scale(loss).backward() call here?

mori360 avatar Jan 24 '25 19:01 mori360

I think there are a few options.

1- There is a new grad-scaling feature inside pipelining. You can enable scale_grads=True; assuming you just want to scale by the number of microbatches, it should "just work" (https://pytorch.org/docs/main/distributed.pipelining.html#torch.distributed.pipelining.schedules.ScheduleGPipe).

1a- As an extension of this, it would not be hard to expose a way to pass a custom scale value or scaler object into pipelining, but it isn't implemented today. Would that be of interest to you? You can also achieve it (hackily) today by putting the scaling inside your loss_fn and passing that scaled loss_fn into the PipelineSchedule, as in the sketch below.
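
A rough sketch of that loss_fn hack (the stage construction and the scale factor here are just illustrative):

import torch.nn.functional as F
from torch.distributed.pipelining import ScheduleGPipe

n_microbatches = 8
scale = 1.0 / n_microbatches  # or any custom factor, e.g. taken from a GradScaler

def scaled_loss_fn(output, target):
    # Because the schedule runs backward internally, scaling the loss here is
    # enough to make the internally computed gradients come out scaled.
    return scale * F.cross_entropy(output, target)

# `stage` is assumed to be a torch.distributed.pipelining.PipelineStage built
# elsewhere for this rank's model chunk.
pp_schedule = ScheduleGPipe(stage, n_microbatches=n_microbatches, loss_fn=scaled_loss_fn)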

2- Instead of scaling the loss, you could scale the grads after they are accumulated but before stepping the optimizer. (This would be done manually, by just iterating the parameters of the pipeline stage submod and performing the scaling; see the sketch below.)
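
A rough sketch of this manual approach (stage_module and the scale factor are placeholders for your stage's submodule and whatever factor you need):

# After pp_schedule.step(...) has run the internal backward and accumulated
# grads, scale them in place before stepping the optimizer.
scale = 1.0 / n_microbatches  # illustrative choice of factor
with torch.no_grad():
    for p in stage_module.parameters():  # stage_module: this rank's stage submodule
        if p.grad is not None:
            p.grad.mul_(scale)

optimizer.step()
optimizer.zero_grad()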

wconstab avatar Jan 24 '25 20:01 wconstab