Gradient Scaling With Pipeline Parallelism
The idiomatic way to perform gradient scaling is something like this:
preds = model(inputs)
loss = loss_fn(preds, targets)
scaler.scale(loss).backward()
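For context, a slightly fuller sketch of the same pattern in a plain (non-PP) training loop, assuming a CUDA autocast region, a single optimizer, and a torch.amp.GradScaler; the optimizer and dataloader names here are illustrative:

import torch

scaler = torch.amp.GradScaler("cuda")
for inputs, targets in dataloader:
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        preds = model(inputs)
        loss = loss_fn(preds, targets)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
    scaler.update()                # adjusts the scale factor for the next step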
Given that the current PyTorch PP API handles the backward pass internally, I find it difficult to do gradient scaling under a PP regime.
if is_first_stage:
    pp_schedule.step(inputs)  # bwd performed internally
elif is_last_stage:
    losses = []
    pp_schedule.step(target=targets, losses=losses)  # bwd performed internally
else:
    pp_schedule.step()  # bwd performed internally

loss = (
    torch.mean(torch.stack(losses)).to(device)
    if is_last_stage
    else torch.tensor([-1.0], device=device)
)
# scaler.scale(loss).backward() <-- !? backward pass has already been performed
Is there currently a good way to do gradient scaling with Pipeline Parallelism? And if not, will the Pipeline Parallelism API support gradient scaling in the near-term future?
scaler.scale(loss).backward()
In torchtitan, we have gradient scaling with the backward pass already performed ahead of it; could I have more information on how scaler.scale(loss).backward() fits in here?
I think there are a few options.

1 - There is a new grad-scale feature inside pipelining. You can enable 'scale_grads=True', and assuming you just want to scale by num_microbatches, it should 'just work' (https://pytorch.org/docs/main/distributed.pipelining.html#torch.distributed.pipelining.schedules.ScheduleGPipe).

1a - As an extension of this, it would not be hard to expose a way to pass a custom scale value or 'scaler' object into pipelining, but that isn't implemented today. Would it be of interest to you? You can also achieve it (hackily) today by putting the scaling inside your loss_fn and passing that scaled loss_fn into the PipelineSchedule.
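A minimal sketch of that hacky route, assuming a recent PyTorch with the scale_grads flag, an existing torch.amp.GradScaler named scaler, and a loss_fn with the usual (output, target) signature; stage and n_microbatches are illustrative names:

from torch.distributed.pipelining import ScheduleGPipe

def scaled_loss_fn(output, target):
    # Scale each microbatch loss so the schedule's internal backward
    # produces gradients scaled by the GradScaler.
    return scaler.scale(loss_fn(output, target))

pp_schedule = ScheduleGPipe(
    stage,                           # this rank's PipelineStage
    n_microbatches=n_microbatches,
    loss_fn=scaled_loss_fn,
    scale_grads=True,                # option 1: also average grads over microbatches
)

The same scaler instance would then still be used for scaler.step(optimizer) and scaler.update() after pp_schedule.step(...).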
2 - Instead of scaling the loss, you could scale the grads after they are accumulated but before stepping the optimizer. (This would be done manually, by iterating the parameters of the pipeline stage submod and performing the scaling yourself.)
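A minimal sketch of option 2, assuming the gradients have already been accumulated by pp_schedule.step(...), that the stage's wrapped module is reachable as stage.submod, and that you simply want to divide by the number of microbatches; the names are illustrative:

with torch.no_grad():
    # Rescale the accumulated gradients on this stage's parameters
    # before the optimizer consumes them.
    for p in stage.submod.parameters():
        if p.grad is not None:
            p.grad.div_(n_microbatches)

optimizer.step()
optimizer.zero_grad()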