[PyTorch] Normalization ops
Description
This PR extends the operation-based API (see https://github.com/NVIDIA/TransformerEngine/pull/707) with LayerNorm, RMSNorm, and FP8 cast operations.
Compare with the existing module-based API:
```python
# Module-based API
module1 = te.LayerNormLinear(...)

# Operation-based API
module2 = te.ops.Sequential(
    te.ops.LayerNorm(...),
    te.ops.Linear(...),
)
```
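For context, here is a minimal sketch of composing the newly added RMSNorm op with a Linear op in the operation-based API. The constructor arguments and tensor sizes are illustrative assumptions, not taken verbatim from this PR:

```python
import torch
import transformer_engine.pytorch as te

hidden_size, ffn_size = 1024, 4096  # illustrative sizes (assumed)

# RMSNorm followed by Linear, mirroring the LayerNorm example above.
# Constructor arguments are assumed to follow the module-based API.
model = te.ops.Sequential(
    te.ops.RMSNorm(hidden_size),
    te.ops.Linear(hidden_size, ffn_size),
)

# Transformer Engine ops run on GPU.
x = torch.randn(32, hidden_size, device="cuda")
y = model(x)  # shape: (32, ffn_size)
```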
Type of change
- [ ] Documentation change (change only to the documentation, either a fix or new content)
- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Infra/Build change
- [ ] Code refactor
Changes
Please list the changes introduced in this PR:
- LayerNorm operation
- FP8 cast operation
- RMSNorm operation
Checklist:
- [x] I have read and followed the contributing guidelines
- [x] The functionality is complete
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] New and existing unit tests pass locally with my changes
/te-ci pytorch
Edit: the te-ci/docs failure disappears when the job is rerun.
/te-ci pytorch
Merging with approval from @ptrendx and @ksivaman.
Hi, one question about this new LayerNorm implementation: my understanding is that it supports multi-dimensional LayerNorm weights, whereas the previous implementation only supports one-dimensional weights.
So if I have N different 1-D tensors, previously I had to instantiate N different LayerNorms and apply them separately. With this new implementation, can we apply a single 2-dimensional LayerNorm to the N stacked tensors with shape (N, dim)? Is my understanding correct?
@binxuan This implementation matches torch.nn.LayerNorm. With a multi-dimensional normalized_shape, it is equivalent to flattening the normalized dimensions and applying a 1-D LayerNorm:

```python
from math import prod

# Flatten the normalized dimensions, apply LayerNorm over the last
# dimension, then restore the original shape.
x_2d = x.reshape(-1, prod(normalized_shape))
y_2d = layer_norm_2d(x_2d, weight.reshape(-1), bias.reshape(-1))
y = y_2d.reshape(x.size())
```
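To make the semantics concrete, here is a small self-contained check of that equivalence against torch.nn.functional.layer_norm (the shapes below are illustrative); note that, as in torch.nn.LayerNorm, the statistics are computed jointly over all normalized dimensions:

```python
import torch
import torch.nn.functional as F
from math import prod

normalized_shape = (4, 16)             # e.g. N=4 stacked 1-D tensors of dim 16
x = torch.randn(8, *normalized_shape)  # batch of 8

weight = torch.randn(normalized_shape)
bias = torch.randn(normalized_shape)

# Direct multi-dimensional LayerNorm.
y = F.layer_norm(x, normalized_shape, weight, bias)

# Equivalent flattened 2-D formulation from the reply above.
x_2d = x.reshape(-1, prod(normalized_shape))
y_2d = F.layer_norm(x_2d, (prod(normalized_shape),),
                    weight.reshape(-1), bias.reshape(-1))

torch.testing.assert_close(y, y_2d.reshape(x.size()))
```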