Update FP8 scale-inverse in kernels with FP8 output
Description
We currently treat the FP8 scale-inverse (the dequantization scaling factor) as part of the FP8 recipe, along with the FP8 scale (the quantization scaling factor) and the absmax history. However, this is awkward because any change to the FP8 recipe invalidates the corresponding FP8 data. We work around this by making copies of the scale-invs whenever a recipe update might occur, e.g. between the forward and backward passes of the linear layer: https://github.com/NVIDIA/TransformerEngine/blob/6717554f11f9b8bd79f917560e525d538c95b3bc/transformer_engine/pytorch/module/linear.py#L318
This adds non-trivial CPU overhead (I estimate ~20% for the PyTorch linear layer forward pass on an L40).
A better approach is to treat the scale-inv as part of the FP8 data, something that should be output along with the FP8 bits and should never change independently of the FP8 bits. The FP8 recipe tells us how we want to cast into FP8, while the scale-inv tells us how to convert back to higher precision. Note that this generalizes nicely to block-scaling schemes, where the scale-inv tensor may be large and must be packaged with the data during communication.
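To make this concrete, here is a minimal sketch (the `Fp8Tensor` name and layout are hypothetical, not actual TransformerEngine types) of what it means for the scale-inv to be part of the FP8 data: the dequantization factor lives alongside the FP8 bits and is only ever written by the kernel that produced them.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical container: the scale-inverse travels with the FP8 bits, so the
// pair can never go out of sync with later FP8 recipe updates.
struct Fp8Tensor {
  uint8_t *data;      // FP8 bits (E4M3 or E5M2), device pointer
  size_t numel;       // number of elements
  float *scale_inv;   // dequantization factor(s), written by the same kernel
                      // that wrote `data`; a single value for per-tensor
                      // scaling, or an array under block-scaling schemes
  size_t num_scales;  // 1 for per-tensor scaling, >1 for block scaling
};
```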
This PR takes an initial step toward this scheme by including scale-inv updates in most of the kernels with FP8 output: casting, activations, LayerNorm, and RMSNorm. cuBLAS doesn't seem to support this, so I've added a small kernel that is launched after FP8 GEMMs. I have not attempted to propagate this change into Userbuffers or attention. I've also updated the PyTorch Linear and LayerNormLinear modules to avoid maintaining extra copies of the scale-inv, and I see a 1.12x speedup in the Linear forward pass.
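For illustration, the post-GEMM fix-up could look something like the following sketch (names are made up and this is not necessarily the kernel added in this PR): a trivial single-thread kernel, enqueued on the same stream as the GEMM, that records the reciprocal of the quantization scale used for the FP8 output.

```cpp
#include <cuda_runtime.h>

// Sketch of a scale-inv update after an FP8 GEMM: cuBLAS quantizes the output
// with *scale but cannot write the matching dequantization factor, so a tiny
// follow-up kernel does it.
__global__ void update_scale_inv(const float *scale, float *scale_inv) {
  *scale_inv = 1.0f / (*scale);
}

// Launch on the GEMM's stream so the update is ordered after the GEMM.
void launch_update_scale_inv(const float *scale, float *scale_inv,
                             cudaStream_t stream) {
  update_scale_inv<<<1, 1, 0, stream>>>(scale, scale_inv);
}
```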
I'm a little apprehensive since this is technically a breaking change: every time we generate FP8 values we will overwrite the FP8 recipe's scale-inv. That said, I have a hard time imagining a case where we would want to use a stale FP8 scale-inv once the FP8 data it described has already been overwritten.
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or a new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [x] Code refactor
Changes
- Update FP8 scale-inverse in cast-transpose kernels
- Update FP8 scale-inverse in cast and activation kernels (see the sketch after this list)
- Update FP8 scale-inverse in LayerNorm and RMSNorm kernels
- Update FP8 scale-inverse after FP8 GEMMs
- Avoid unnecessary FP8 scale-inverse copies in PyTorch `Linear` module
- Avoid unnecessary FP8 scale-inverse copies in PyTorch `LayerNormLinear` module
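As a concrete illustration of the kernel-side changes above, a fused FP8 cast can update the scale-inverse in the same launch that produces the FP8 bits. This is a simplified, hypothetical sketch rather than the actual cast-transpose/activation kernels:

```cpp
#include <cstddef>
#include <cuda_fp8.h>

// Simplified FP8 cast: quantize with *scale and have one thread record the
// matching dequantization factor, so the scale-inv is always consistent with
// the FP8 data just written.
__global__ void cast_to_fp8(const float *in, __nv_fp8_e4m3 *out, size_t n,
                            const float *scale, float *scale_inv) {
  const size_t i = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = __nv_fp8_e4m3(in[i] * (*scale));
  }
  if (i == 0) {
    *scale_inv = 1.0f / (*scale);
  }
}
```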
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [ ] I have commented my code, particularly in hard-to-understand areas
- [ ] I have made corresponding changes to the documentation
- [ ] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes