Megatron-LM

[BUG] Loss difference when training with FP8 vs. BF16 MoE

Open viclzhu opened this issue 5 months ago • 6 comments

Describe the bug
When enabling FP8 mixed precision while training a Mixtral model (SequentialMLP expert layer), we observe that training and validation loss diverge from the BF16 run by more than expected.

To Reproduce
Start with examples/mixtral/train_mixtral_8x7b_distributed.sh and apply the following changes:

  • Disable --moe-grouped-gemm.
  • Pass --fp8-format hybrid --fp8-amax-compute-algo max --fp8-amax-history-len 1024 (see the flag sketch below).

We use tokenizer.model from https://huggingface.co/mistralai/Mixtral-8x7B-v0.1.
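
For reference, the FP8-related argument changes can be collected as below. This is a minimal sketch assuming the rest of examples/mixtral/train_mixtral_8x7b_distributed.sh is left as shipped; the FP8_ARGS variable name is illustrative, not part of the script itself.

```bash
# Sketch of the launch-argument changes (FP8_ARGS is a hypothetical name).
# 1) Use SequentialMLP experts: do NOT pass --moe-grouped-gemm.
# 2) Enable the FP8 recipe:
FP8_ARGS=(
    --fp8-format hybrid            # E4M3 for forward tensors, E5M2 for gradients
    --fp8-amax-compute-algo max
    --fp8-amax-history-len 1024
)

# Append to the script's existing launch command, e.g.
#   torchrun $DISTRIBUTED_ARGS pretrain_gpt.py ... "${FP8_ARGS[@]}"
```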

Expected behavior
Training and validation loss should be approximately the same for the BF16 and FP8 MoE runs.

Stack trace/logs

# BF16
3:  [2024-09-20 19:40:51] iteration        1/     100 | consumed samples:          256 | elapsed time per iteration (ms): 46555.7 | throughput per GPU (TFLOP/s/GPU): 58.4 | learning rate: 2.000000E-07 | global batch size:   256 | lm loss: 1.037833E+01 | loss scale: 1.0 | grad norm: 1.664 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 19:41:04] iteration        2/     100 | consumed samples:          512 | elapsed time per iteration (ms): 13477.5 | throughput per GPU (TFLOP/s/GPU): 201.6 | learning rate: 4.000000E-07 | global batch size:   256 | lm loss: 1.037832E+01 | loss scale: 1.0 | grad norm: 1.712 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 19:41:18] iteration        3/     100 | consumed samples:          768 | elapsed time per iteration (ms): 13434
...
3:  [2024-09-20 20:04:41] iteration       98/     100 | consumed samples:        25088 | elapsed time per iteration (ms): 15304.1 | throughput per GPU (TFLOP/s/GPU): 177.6 | learning rate: 1.960000E-05 | global batch size:   256 | lm loss: 7.280738E+00 | loss scale: 1.0 | grad norm: 0.965 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 20:04:57] iteration       99/     100 | consumed samples:        25344 | elapsed time per iteration (ms): 15364.6 | throughput per GPU (TFLOP/s/GPU): 176.9 | learning rate: 1.980000E-05 | global batch size:   256 | lm loss: 7.251996E+00 | loss scale: 1.0 | grad norm: 0.690 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 20:05:12] iteration      100/     100 | consumed samples:        25600 | elapsed time per iteration (ms): 15208.5 | throughput per GPU (TFLOP/s/GPU): 178.7 | learning rate: 2.000000E-05 | global batch size:   256 | lm loss: 7.260425E+00 | loss scale: 1.0 | grad norm: 0.632 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  validation loss at iteration 100 on validation set | lm loss value: 7.251174E+00 | lm loss PPL: 1.409760E+03 |
3:  validation loss at iteration 100 on test set | lm loss value: 7.255818E+00 | lm loss PPL: 1.416322E+03 |
# FP8
3:  [2024-09-20 19:41:08] iteration        1/     100 | consumed samples:          256 | elapsed time per iteration (ms): 62104.1 | throughput per GPU (TFLOP/s/GPU): 43.8 | learning rate: 2.000000E-07 | global batch size:   256 | lm loss: 1.037847E+01 | loss scale: 1.0 | grad norm: 0.534 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 19:41:21] iteration        2/     100 | consumed samples:          512 | elapsed time per iteration (ms): 13276.7 | throughput per GPU (TFLOP/s/GPU): 204.7 | learning rate: 4.000000E-07 | global batch size:   256 | lm loss: 1.037833E+01 | loss scale: 1.0 | grad norm: 0.571 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 19:41:34] iteration        3/     100 | consumed samples:          768 | elapsed time per iteration (ms): 13319.2 | throughput per GPU (TFLOP/s/GPU): 204.0 | learning rate: 6.000000E-07 | global batch size:   256 | lm loss: 1.037832E+01 | loss scale: 1.0 | grad norm: 0.568 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
...
3:  [2024-09-20 20:03:04] iteration       98/     100 | consumed samples:        25088 | elapsed time per iteration (ms): 13616.6 | throughput per GPU (TFLOP/s/GPU): 199.6 | learning rate: 1.960000E-05 | global batch size:   256 | lm loss: 7.739647E+00 | loss scale: 1.0 | grad norm: 4.661 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 20:03:18] iteration       99/     100 | consumed samples:        25344 | elapsed time per iteration (ms): 13650.1 | throughput per GPU (TFLOP/s/GPU): 199.1 | learning rate: 1.980000E-05 | global batch size:   256 | lm loss: 7.716366E+00 | loss scale: 1.0 | grad norm: 4.697 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  [2024-09-20 20:03:32] iteration      100/     100 | consumed samples:        25600 | elapsed time per iteration (ms): 13625.0 | throughput per GPU (TFLOP/s/GPU): 199.5 | learning rate: 2.000000E-05 | global batch size:   256 | lm loss: 7.721111E+00 | loss scale: 1.0 | grad norm: 4.632 | num zeros: 0.0 | number of skipped iterations:   0 | number of nan iterations:   0 |
3:  validation loss at iteration 100 on validation set | lm loss value: 7.706149E+00 | lm loss PPL: 2.221969E+03 |
3:  validation loss at iteration 100 on test set | lm loss value: 7.708397E+00 | lm loss PPL: 2.226969E+03 |

Full logs attached: moe_megatron_bf16_22455.log, moe_megatron_fp8_22454.log
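
For a quick side-by-side view of the divergence, the per-iteration lm loss can be pulled out of the two attached logs with a small shell snippet (a sketch assuming the logs keep the format shown above; output file names are arbitrary):

```bash
# Extract the per-iteration "lm loss" values from each log and print them
# side by side (BF16 in column 1, FP8 in column 2).
grep -o 'lm loss: [0-9.E+-]*' moe_megatron_bf16_22455.log | awk '{print $3}' > bf16_loss.txt
grep -o 'lm loss: [0-9.E+-]*' moe_megatron_fp8_22454.log  | awk '{print $3}' > fp8_loss.txt
paste bf16_loss.txt fp8_loss.txt
```

By iteration 100 the gap is roughly 7.26 (BF16) vs. 7.72 (FP8) training loss, and 7.25 vs. 7.71 validation loss.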

Environment (please complete the following information):

  • Megatron-LM commit ID: 835af44a3
  • Megatron-core version: 0.9.0rc0
  • PyTorch version: 2.3.0a0+40ec155e58.nv24.3
  • CUDA version: 12.4
  • NCCL version: 2.19.4
  • TransformerEngine version: 1.8.0.dev0+7d576ed
  • Base Container: nvcr.io/nvidia/nemo:24.07

Additional context
We also experimented with enabling FP8 with the TEGroupedMLP module (padding expert inputs for FP8) and saw similar loss differences there as well.

viclzhu avatar Sep 20 '24 22:09 viclzhu