[BUG] Loss difference when training with FP8 vs. BF16 MoE
Describe the bug
When enabling FP8 mixed precision while training a Mixtral model (with the SequentialMLP expert layer), we observe that the training and validation losses differ from the BF16 baseline by more than expected.
To Reproduce
Start with `examples/mixtral/train_mixtral_8x7b_distributed.sh`:
- Disable `--moe-grouped-gemm`.
- Pass `--fp8-format hybrid --fp8-amax-compute-algo max --fp8-amax-history-len 1024` (see the sketch below).

The `tokenizer.model` is from https://huggingface.co/mistralai/Mixtral-8x7B-v0.1.
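For concreteness, a minimal sketch of the script changes is below. The `FP8_ARGS` variable name and the surrounding structure are illustrative assumptions; only the flags themselves are the ones listed above.

```bash
# Illustrative edit to examples/mixtral/train_mixtral_8x7b_distributed.sh
# (variable name and structure are assumptions; only the flags match the run above).

# 1) Remove --moe-grouped-gemm from the MoE options so the SequentialMLP
#    expert implementation is used instead of grouped GEMM.

# 2) Add the FP8 recipe (TE delayed scaling: max over a 1024-step amax history):
FP8_ARGS=(
    --fp8-format hybrid            # E4M3 for forward tensors, E5M2 for gradients
    --fp8-amax-compute-algo max
    --fp8-amax-history-len 1024
)

# ...then append "${FP8_ARGS[@]}" to the training script's argument list.
```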
Expected behavior
Training and validation losses for the BF16 and FP8 MoE runs should be approximately the same.
Stack trace/logs
# BF16
3: [2024-09-20 19:40:51] iteration 1/ 100 | consumed samples: 256 | elapsed time per iteration (ms): 46555.7 | throughput per GPU (TFLOP/s/GPU): 58.4 | learning rate: 2.000000E-07 | global batch size: 256 | lm loss: 1.037833E+01 | loss scale: 1.0 | grad norm: 1.664 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: [2024-09-20 19:41:04] iteration 2/ 100 | consumed samples: 512 | elapsed time per iteration (ms): 13477.5 | throughput per GPU (TFLOP/s/GPU): 201.6 | learning rate: 4.000000E-07 | global batch size: 256 | lm loss: 1.037832E+01 | loss scale: 1.0 | grad norm: 1.712 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: [2024-09-20 19:41:18] iteration 3/ 100 | consumed samples: 768 | elapsed time per iteration (ms): 13434
...
3: [2024-09-20 20:04:41] iteration 98/ 100 | consumed samples: 25088 | elapsed time per iteration (ms): 15304.1 | throughput per GPU (TFLOP/s/GPU): 177.6 | learning rate: 1.960000E-05 | global batch size: 256 | lm loss: 7.280738E+00 | loss scale: 1.0 | grad norm: 0.965 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: [2024-09-20 20:04:57] iteration 99/ 100 | consumed samples: 25344 | elapsed time per iteration (ms): 15364.6 | throughput per GPU (TFLOP/s/GPU): 176.9 | learning rate: 1.980000E-05 | global batch size: 256 | lm loss: 7.251996E+00 | loss scale: 1.0 | grad norm: 0.690 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: [2024-09-20 20:05:12] iteration 100/ 100 | consumed samples: 25600 | elapsed time per iteration (ms): 15208.5 | throughput per GPU (TFLOP/s/GPU): 178.7 | learning rate: 2.000000E-05 | global batch size: 256 | lm loss: 7.260425E+00 | loss scale: 1.0 | grad norm: 0.632 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: validation loss at iteration 100 on validation set | lm loss value: 7.251174E+00 | lm loss PPL: 1.409760E+03 |
3: validation loss at iteration 100 on test set | lm loss value: 7.255818E+00 | lm loss PPL: 1.416322E+03 |
# FP8
3: [2024-09-20 19:41:08] iteration 1/ 100 | consumed samples: 256 | elapsed time per iteration (ms): 62104.1 | throughput per GPU (TFLOP/s/GPU): 43.8 | learning rate: 2.000000E-07 | global batch size: 256 | lm loss: 1.037847E+01 | loss scale: 1.0 | grad norm: 0.534 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: [2024-09-20 19:41:21] iteration 2/ 100 | consumed samples: 512 | elapsed time per iteration (ms): 13276.7 | throughput per GPU (TFLOP/s/GPU): 204.7 | learning rate: 4.000000E-07 | global batch size: 256 | lm loss: 1.037833E+01 | loss scale: 1.0 | grad norm: 0.571 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: [2024-09-20 19:41:34] iteration 3/ 100 | consumed samples: 768 | elapsed time per iteration (ms): 13319.2 | throughput per GPU (TFLOP/s/GPU): 204.0 | learning rate: 6.000000E-07 | global batch size: 256 | lm loss: 1.037832E+01 | loss scale: 1.0 | grad norm: 0.568 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
...
3: [2024-09-20 20:03:04] iteration 98/ 100 | consumed samples: 25088 | elapsed time per iteration (ms): 13616.6 | throughput per GPU (TFLOP/s/GPU): 199.6 | learning rate: 1.960000E-05 | global batch size: 256 | lm loss: 7.739647E+00 | loss scale: 1.0 | grad norm: 4.661 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: [2024-09-20 20:03:18] iteration 99/ 100 | consumed samples: 25344 | elapsed time per iteration (ms): 13650.1 | throughput per GPU (TFLOP/s/GPU): 199.1 | learning rate: 1.980000E-05 | global batch size: 256 | lm loss: 7.716366E+00 | loss scale: 1.0 | grad norm: 4.697 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: [2024-09-20 20:03:32] iteration 100/ 100 | consumed samples: 25600 | elapsed time per iteration (ms): 13625.0 | throughput per GPU (TFLOP/s/GPU): 199.5 | learning rate: 2.000000E-05 | global batch size: 256 | lm loss: 7.721111E+00 | loss scale: 1.0 | grad norm: 4.632 | num zeros: 0.0 | number of skipped iterations: 0 | number of nan iterations: 0 |
3: validation loss at iteration 100 on validation set | lm loss value: 7.706149E+00 | lm loss PPL: 2.221969E+03 |
3: validation loss at iteration 100 on test set | lm loss value: 7.708397E+00 | lm loss PPL: 2.226969E+03 |
Full logs attached: moe_megatron_bf16_22455.log, moe_megatron_fp8_22454.log
Environment (please complete the following information):
- Megatron-LM commit ID: 835af44a3
- Megatron-core version: 0.9.0rc0
- PyTorch version: 2.3.0a0+40ec155e58.nv24.3
- CUDA version: 12.4
- NCCL version: 2.19.4
- TransformerEngine version: 1.8.0.dev0+7d576ed
- Base Container: nvcr.io/nvidia/nemo:24.07
Additional context
We also experimented with enabling FP8 with the TEGroupedMLP module (padding the inputs for FP8) and observed similar loss differences there as well.
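For reference, a sketch of that variant's flag combination is below. It is an assumption reconstructed from the flags already listed above, not the exact command we ran, and it does not show how the expert inputs were padded for FP8.

```bash
# Assumed flag combination for the TEGroupedMLP experiment: keep grouped GEMM enabled
# and reuse the same FP8 recipe as above. Input padding for FP8 was done as described
# in the text and is not shown here.
GROUPED_FP8_ARGS=(
    --moe-grouped-gemm
    --fp8-format hybrid
    --fp8-amax-compute-algo max
    --fp8-amax-history-len 1024
)
```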