[Pytorch] change fused cross entropy backward grad to fp32 and reduce one read/…
Description
The fused cross entropy kernel in Transformer Engine performs the backward pass in 16-bit floating point (BF16) when the input is BF16, whereas Megatron's VocabParallelCrossEntropy performs its computations in FP32. This discrepancy may lead to training divergence in some cases.
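A minimal, pure-Python sketch of why the precision of the backward pass matters. It simulates BF16 by truncating the FP32 mantissa and computes the cross-entropy gradient (`softmax(x) - onehot(y)`) in both precisions; the logit values, `to_bf16`, and `ce_grad` are illustrative assumptions, not the actual kernel code.

```python
import math
import struct

def to_bf16(x: float) -> float:
    """Simulate bfloat16 by truncating an fp32 value's mantissa to 7 bits."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

def ce_grad(logits, target, cast=lambda v: v):
    """Cross-entropy backward: dX = softmax(X) - onehot(target), with each
    intermediate pushed through `cast` to emulate the working precision."""
    m = max(cast(v) for v in logits)
    exps = [math.exp(cast(v) - m) for v in logits]
    s = sum(cast(e) for e in exps)
    grad = [cast(cast(e) / s) for e in exps]
    grad[target] = cast(grad[target] - 1.0)
    return grad

logits = [2.1, -0.3, 0.7, 1.9]  # illustrative values
g_fp32 = ce_grad(logits, target=0)
g_bf16 = ce_grad(logits, target=0, cast=to_bf16)
diff = max(abs(a - b) for a, b in zip(g_fp32, g_bf16))  # small but nonzero
```

Even on a single token the per-element gradient error is on the order of BF16's ~3 decimal digits of precision; accumulated over many steps, such errors are a plausible source of the divergence reported below.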
This PR also reduces one read of logits, which improves the performance by up to 1.25x.
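One common way to remove a pass over the logits is the online-softmax trick, which maintains the running max and the rescaled sum of exponentials in a single sweep instead of one pass for the max and a second for the sum. This is only an illustrative sketch of that technique, not necessarily the exact restructuring done in this PR:

```python
import math

def online_max_sum(xs):
    """One-pass (online) max and sum-of-exps: the values are read once."""
    m, s = float("-inf"), 0.0
    for x in xs:
        if x > m:
            s *= math.exp(m - x)  # rescale running sum to the new max
            m = x
        s += math.exp(x - m)
    return m, s

logits = [2.1, -0.3, 0.7, 1.9]  # illustrative values
m1, s1 = online_max_sum(logits)

# Two-pass reference: one read for the max, a second read for the sum.
m2 = max(logits)
s2 = sum(math.exp(x - m2) for x in logits)
```

For a memory-bound kernel whose main cost is streaming a `n_rows × n_cols` logits tensor from HBM, cutting one of the reads is consistent with the reported up-to-1.25x speedup.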
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or new content)
- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [ ] Code refactoring
Changes
Please list the changes introduced in this PR:
- Changed the fused cross entropy backward gradient computation to fp32 for consistency with Megatron's VocabParallelCrossEntropy.
- Optimized the computation logic to reduce one read/write operation of the logits.
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
@sanandaraj5597 @timmoon10 Could you please review? The previous BF16 backward may lead to divergence in some cases (reported by several customers).
@RandMist You need to sign-off your commits (git commit -s). See this.
@RandMist @yaox12 Hello~ Our experiments found that after applying this change, the output can sometimes become inf. This subsequently leads to NaN gradients for the corresponding token during the RMS backward pass. Furthermore, this is a non-deterministic issue that cannot be reliably reproduced on a specific input.
The goal of this kernel was to avoid saving the input for backward: it writes the gradients into the input tensor itself to reduce peak memory usage. This change defeats that purpose and breaks our upstream workloads that use this kernel.
I understand you have improved performance but it comes at the cost of extra memory usage which we don't want.
> @RandMist @yaox12 Hello~ Our experiments found that after applying this change, the output can sometimes become inf. This subsequently leads to NaN gradients for the corresponding token during the RMS backward pass. Furthermore, this is a non-deterministic issue that cannot be reliably reproduced on a specific input.
Hi, I will try to reproduce this phenomenon, or you can provide a script/data for reproduction.
> The goal of this kernel was to avoid saving the input for backward. The goal is to write the gradients on the input tensor itself to reduce the peak memory usage. This change defeats that purpose and break our upstream workloads using this kernel.
>
> I understand you have improved performance but it comes at the cost of extra memory usage which we don't want.
Hi, the main objective of this new kernel is to keep the backward-pass computations in FP32.
The original kernel stored the gradients in the input tensor, whereas the new kernel keeps the input intact (saved for backward) and additionally saves the target and m_d_X_y. The shapes of these extra tensors are n_rows × 1 and n_rows × 4, respectively. Compared to the input shape of n_rows × n_cols (where n_cols is typically >100,000), this additional memory consumption should be negligible.
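A back-of-the-envelope check of that claim, using assumed (illustrative) sizes for the token count and vocabulary:

```python
# Illustrative sizes: n_rows = tokens per step, n_cols = vocab size (assumptions).
n_rows, n_cols = 8192, 131_072

saved_input = n_rows * n_cols   # logits kept for backward (elements)
extra = n_rows * (1 + 4)        # target (n_rows x 1) + m_d_X_y (n_rows x 4)
overhead = extra / saved_input  # fraction of additional memory
```

With a 131k-entry vocabulary the extra tensors add roughly 0.004% on top of the saved logits, i.e. well under the measurement noise of peak-memory profiling.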
> @RandMist @yaox12 Hello~ Our experiments found that after applying this change, the output can sometimes become inf. This subsequently leads to NaN gradients for the corresponding token during the RMS backward pass. Furthermore, this is a non-deterministic issue that cannot be reliably reproduced on a specific input.
We did not observe this phenomenon in our internal experiments. We trained the 16B MoE model for 20k steps, and the loss/grad norm of the model using the new fused kernel fully matched that of the model using mcore's original unfused loss in fp32 (the old fused kernel with bf16 backward showed an upward trend in grad norm after 15k steps).