TransformerEngine

[PyTorch] Bug in FP8 buffer update causing training instabilities

Open · Marks101 opened this issue 7 months ago · 1 comment

Hello team,

We have been debugging large-scale training instabilities with FP8 and noticed that they started when we updated from transformer-engine v1.2.1 to v1.7. Taking a closer look at the trainings, I noticed that the first iteration shows a larger loss than in trainings with the old version or in trainings with BF16. I was able to reproduce this with the following minimal example:

import torch
import transformer_engine as te

torch.manual_seed(12345)
model = te.pytorch.Linear(768, 768)
inp = torch.rand((1024, 768), device="cuda")

# reference output without FP8 autocast
out_ref = model(inp)

# FP8 forward passes: is_first_microbatch=True on the first microbatch casts and
# caches the FP8 weights, subsequent microbatches reuse the cached buffers
for micro_batch in range(5):
    with te.pytorch.fp8_autocast():
        out_fp8 = model(inp, is_first_microbatch=(micro_batch == 0))
    max_diff = torch.max(torch.abs(out_fp8 - out_ref)).item()
    print(f"{micro_batch=}: {max_diff=}")

Executing this piece of code with version v1.2.1 gives me:

micro_batch=0: max_diff=0.06777709722518921
micro_batch=1: max_diff=0.07292889058589935
micro_batch=2: max_diff=0.07292889058589935
micro_batch=3: max_diff=0.07292889058589935
micro_batch=4: max_diff=0.07292889058589935

In comparison, with version v1.7:

micro_batch=0: max_diff=0.06777709722518921
micro_batch=1: max_diff=1.7055679559707642
micro_batch=2: max_diff=1.7055679559707642
micro_batch=3: max_diff=1.7055679559707642
micro_batch=4: max_diff=1.7055679559707642

After the first microbatch, the Linear module produces wrong results with version v1.7. Could you please try to reproduce this?

This is specifically tied to the use of is_first_microbatch. The same bug applies to LayerNormLinear and thus to MultiheadAttention and so on. We bisected this and came to the conclusion that it was introduced with #575 and got (by coincidence?) fixed with the refactorings in #820; it seems to be connected to the update of the FP8 buffers. I am not 100% sure, but it looks to me as if information from old iterations is reused during training, and this can cause instabilities. Overall, the current version v1.8 is not affected. Still, if you are able to reproduce this, and considering that this is a "silent" bug that caused heavy instabilities on our side, it might be worth adding it to the "known issues" section.
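To illustrate why reusing stale FP8 state can blow up the output, here is a minimal, self-contained sketch of delayed scaling. This is not TE code; the FP8_MAX value, the crude rounding, and the fake_fp8_quantize helper are assumptions made only for illustration. The point is that a scale derived from an outdated amax produces a much larger quantization error than a scale derived from the current tensor:

import torch

FP8_MAX = 448.0  # max representable magnitude of the FP8 E4M3 format

def fake_fp8_quantize(x, scale):
    # quantize-dequantize with the given scale; the coarse rounding stands in
    # for the limited FP8 mantissa (this is a simulation, not real FP8)
    x_scaled = torch.clamp(x * scale, -FP8_MAX, FP8_MAX)
    x_q = torch.round(x_scaled * 8) / 8
    return x_q / scale

x = torch.randn(1024) * 5.0

fresh_scale = FP8_MAX / x.abs().max()          # scale from the current amax
stale_scale = FP8_MAX / (x.abs().max() * 100)  # scale from a stale, 100x too large amax

err_fresh = (fake_fp8_quantize(x, fresh_scale) - x).abs().max().item()
err_stale = (fake_fp8_quantize(x, stale_scale) - x).abs().max().item()
print(f"{err_fresh=:.4f} {err_stale=:.4f}")  # the stale scale gives a much larger error

The jump in max_diff after the first microbatch in the reproducer above would be consistent with such a stale scale being applied to the cached weights, but that is only my interpretation.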

Please let me know if you need more information from our side.

More details on our setup: DGX H100, CUDA 12.2, Torch 2.3.0

Marks101 · Jul 26 '24, 11:07