TransformerEngine
[PyTorch] FP8 and activation checkpointing cause training instabilities
Hello team,
we noticed training instabilities when combining FP8 and activation checkpointing with transformer_engine.pytorch.checkpoint. When taking a closer look at this, we got the impression that the FP8 scales for the backward pass are not updated properly. Here is a code snippet that demonstrates the behavior:
import torch
import transformer_engine as te

hidden_size = 768

def run_iterations(checkpoint: bool):
    torch.manual_seed(12345)
    model = te.pytorch.Linear(hidden_size, hidden_size)
    result = []
    for it in range(2):
        x = torch.randn((hidden_size, hidden_size), requires_grad=True, device="cuda")
        y_grad = torch.randn_like(x)
        with te.pytorch.fp8_autocast():
            if checkpoint:
                y = te.pytorch.checkpoint(model, x)
            else:
                y = model(x)
        y.backward(y_grad)
        result.append(dict(it=it, y=y, x_grad=x.grad, scaling_bwd=model.fp8_meta["scaling_bwd"].scale[0].item()))
    return result

result_ref = run_iterations(checkpoint=False)
result_cp = run_iterations(checkpoint=True)

for r_ref, r_cp in zip(result_ref, result_cp):
    max_diff_y = torch.max(torch.abs(r_ref["y"] - r_cp["y"])).item()
    max_diff_x_grad = torch.max(torch.abs(r_ref["x_grad"] - r_cp["x_grad"])).item()
    print(f"it={r_ref['it']}: {max_diff_y=}, {max_diff_x_grad=:.02}")
    print(f" scale bwd: ref={r_ref['scaling_bwd']:.2}, cp={r_cp['scaling_bwd']:.2}")
The snippet creates a linear layer and runs two forward/backward iterations on it with FP8, once with and once without activation checkpointing. The results are recorded and compared at the end.
When executing this code I get the following output:
it=0: max_diff_y=0.0, max_diff_x_grad=0.0
scale bwd: ref=1.1e+04, cp=1.0
it=1: max_diff_y=0.0, max_diff_x_grad=0.21
scale bwd: ref=1.1e+04, cp=1.0
We can see that the forward passes give the same results with and without checkpointing, but the gradient of x differs in the second iteration. Additionally, the scale fp8_meta["scaling_bwd"].scale[0] is not updated in the run with activation checkpointing. My guess is that reduce_and_update_bwd_fp8_tensors is not triggered because FP8GlobalStateManager.is_first_fp8_module() == False in the second forward pass for the activation recompute, which runs outside of the fp8_autocast context, see here.
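To make that hypothesis concrete, here is a minimal, self-contained mock of the gating we suspect. This is only an illustration of our reading of the code path around is_first_fp8_module(); it is not the actual TransformerEngine implementation, and all names other than is_first_fp8_module and reduce_and_update_bwd_fp8_tensors are made up for illustration:

# Mock of the suspected gating (illustration only, not TE source code).
class MockFP8GlobalState:
    inside_autocast = False     # True while inside fp8_autocast
    first_module_seen = False   # reset when an autocast region is entered

    @classmethod
    def is_first_fp8_module(cls):
        # Only the first module called inside an active autocast region
        # "owns" the backward amax reduction / scale update.
        if not cls.inside_autocast or cls.first_module_seen:
            return False
        cls.first_module_seen = True
        return True

# Regular forward: runs inside fp8_autocast, the flag is True, so the backward
# later triggers reduce_and_update_bwd_fp8_tensors and scaling_bwd gets updated.
MockFP8GlobalState.inside_autocast = True
MockFP8GlobalState.first_module_seen = False
print(MockFP8GlobalState.is_first_fp8_module())  # True

# Recompute forward triggered by te.pytorch.checkpoint during backward:
# it runs outside the user's fp8_autocast context, the flag stays False,
# and the backward scale is never updated.
MockFP8GlobalState.inside_autocast = False
MockFP8GlobalState.first_module_seen = False
print(MockFP8GlobalState.is_first_fp8_module())  # False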
Other modules like TransformerLayer show the same issue (see the sketch below). Could you please take a look and check whether you can reproduce these findings? Or are we using checkpointing improperly?
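For reference, here is a sketch of how we check this for TransformerLayer. The constructor arguments and input shape are just a plausible, FP8-friendly configuration chosen for illustration and may need adjusting for your TE version:

import torch
import transformer_engine as te

# Sketch: same check as above, but with TransformerLayer instead of Linear.
layer = te.pytorch.TransformerLayer(
    hidden_size=768, ffn_hidden_size=3072, num_attention_heads=12
)

x = torch.randn((128, 4, 768), requires_grad=True, device="cuda")  # [seq, batch, hidden]
y_grad = torch.randn((128, 4, 768), device="cuda")

for it in range(2):
    with te.pytorch.fp8_autocast():
        y = te.pytorch.checkpoint(layer, x)
    y.backward(y_grad)
    # With checkpointing, the backward scale stays at its initial value here as
    # well, matching what we see for Linear.
    print(it, layer.fp8_meta["scaling_bwd"].scale[0].item())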
Thanks