
[PyTorch] Normalization ops

timmoon10 opened this issue 1 year ago • 10 comments

Description

This PR extends the operation-based API (see https://github.com/NVIDIA/TransformerEngine/pull/707) with LayerNorm, RMSNorm, and FP8 cast operations.

Compare with the existing module-based API:

import transformer_engine.pytorch as te

# Module-based API
module1 = te.LayerNormLinear(...)

# Operation-based API
module2 = te.ops.Sequential(
    te.ops.LayerNorm(...),
    te.ops.Linear(...),
)
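Conceptually, both forms compute a LayerNorm followed by a linear layer. The NumPy sketch below illustrates that equivalence: it chains a LayerNorm step and a Linear step through a simple sequential container and checks the result against a single "fused" function. The names `layer_norm`, `linear`, `sequential`, and `layer_norm_linear` are hypothetical stand-ins written for this example. They are not Transformer Engine APIs and say nothing about how TE actually fuses or schedules kernels.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize over the last dimension, then apply the elementwise affine params
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # biased variance, as in torch.nn.LayerNorm
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

def linear(x, weight, bias):
    # weight has shape (out_features, in_features), like torch.nn.Linear
    return x @ weight.T + bias

def sequential(*ops):
    # Hypothetical analogue of an op container: apply each op in order
    def forward(x):
        for op in ops:
            x = op(x)
        return x
    return forward

def layer_norm_linear(x, gamma, beta, weight, bias, eps=1e-5):
    # Hypothetical analogue of a single module computing both steps at once
    mean = x.mean(axis=-1, keepdims=True)
    rstd = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    return ((x - mean) * rstd * gamma + beta) @ weight.T + bias

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8))
gamma, beta = rng.standard_normal(8), rng.standard_normal(8)
weight, bias = rng.standard_normal((16, 8)), rng.standard_normal(16)

model = sequential(
    lambda t: layer_norm(t, gamma, beta),
    lambda t: linear(t, weight, bias),
)
y_ops = model(x)
y_fused = layer_norm_linear(x, gamma, beta, weight, bias)
print(y_ops.shape, np.allclose(y_ops, y_fused))  # (4, 16) True
```

The container form keeps each step independently reusable, while the fused form is a single unit; numerically they describe the same computation.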

Type of change

  • [ ] Documentation change (change only to the documentation, either a fix or a new content)
  • [ ] Bug fix (non-breaking change which fixes an issue)
  • [x] New feature (non-breaking change which adds functionality)
  • [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • [ ] Infra/Build change
  • [ ] Code refactor

Changes

Please list the changes introduced in this PR:

  • LayerNorm operation
  • FP8 cast operation
  • RMSNorm operation
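For reference, the NumPy sketch below spells out the math behind the RMSNorm and FP8 cast operations listed above. It assumes the common RMSNorm definition: divide by the root-mean-square over the last dimension, then multiply by a learned `gamma`, with no mean subtraction and no bias. It models the FP8 cast as three steps: scale, clamp to the E4M3 range (max 448), and round to nearest-even on the E4M3 grid. This is a simulation for illustration only. `rms_norm`, `round_to_e4m3`, and `fp8_e4m3_roundtrip` are hypothetical names, and Transformer Engine's kernels, scaling recipes, and rounding details may differ.

```python
import numpy as np

E4M3_MAX = 448.0          # largest finite value in FP8 E4M3
E4M3_MIN_EXP = -6         # smallest normal exponent in E4M3
E4M3_MANTISSA_BITS = 3

def rms_norm(x, gamma, eps=1e-5):
    # Unlike LayerNorm, RMSNorm does not subtract the mean and has no bias
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gamma

def round_to_e4m3(v):
    # Round to the nearest representable E4M3 value (ties to even).
    # Subnormals share the grid spacing of the smallest normal binade (2**-9).
    mag = np.abs(v)
    exp = np.floor(np.log2(np.maximum(mag, 2.0 ** E4M3_MIN_EXP)))
    step = 2.0 ** (exp - E4M3_MANTISSA_BITS)
    return np.sign(v) * np.round(mag / step) * step

def fp8_e4m3_roundtrip(x, scale):
    # Quantize: scale into the FP8 range, clamp, and round
    q = round_to_e4m3(np.clip(x * scale, -E4M3_MAX, E4M3_MAX))
    # Dequantize: undo the scale so the result is comparable to x
    return q / scale

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 16))
y = rms_norm(x, gamma=np.ones(16))
scale = E4M3_MAX / np.abs(y).max()  # per-tensor scale from the absolute max
y_fp8 = fp8_e4m3_roundtrip(y, scale)
# Worst-case error is half the top-binade step: at most 16/448 of amax here
print(np.abs(y_fp8 - y).max() / np.abs(y).max())
```

Deriving the scale from the absolute maximum maps the largest value exactly onto 448, so nothing is clipped and the rounding error stays bounded relative to the tensor's range.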

Checklist:

  • [x] I have read and followed the contributing guidelines
  • [x] The functionality is complete
  • [x] I have commented my code, particularly in hard-to-understand areas
  • [x] I have made corresponding changes to the documentation
  • [x] My changes generate no new warnings
  • [x] I have added tests that prove my fix is effective or that my feature works
  • [x] New and existing unit tests pass locally with my changes

timmoon10 • Jul 22 '24

/te-ci pytorch

timmoon10 • Jul 22 '24

/te-ci pytorch

timmoon10 • Jul 22 '24

/te-ci pytorch

timmoon10 • Jul 30 '24

/te-ci pytorch

timmoon10 • Jul 30 '24

/te-ci pytorch

timmoon10 • Aug 12 '24

/te-ci pytorch

timmoon10 • Sep 03 '24

/te-ci pytorch

timmoon10 • Sep 11 '24

/te-ci pytorch

timmoon10 • Sep 19 '24

/te-ci pytorch

Edit: The te-ci/docs failure disappears when the job is rerun.

timmoon10 • Sep 24 '24

/te-ci pytorch

timmoon10 • Oct 01 '24

/te-ci pytorch

timmoon10 • Oct 10 '24

/te-ci pytorch

timmoon10 • Oct 18 '24

/te-ci pytorch

timmoon10 • Nov 05 '24

Merging with approval from @ptrendx and @ksivaman.

timmoon10 • Nov 05 '24

Hi, one question about this new LayerNorm implementation: my understanding is that it supports multi-dimensional LayerNorm weights, while the previous implementation only supports one-dimensional weights.

I have N different 1-D tensors, and previously I had to instantiate N different LayerNorms and apply them separately. With this new implementation, can we instead apply a single 2-D LayerNorm to the N tensors stacked into shape (N, dim)? Is my understanding correct?

binxuan • Dec 19 '24

@binxuan This implementation matches torch.nn.LayerNorm:

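from math import prod

# layer_norm_2d (pseudocode helper): normalize each row, then apply the affine params
x_2d = x.reshape(-1, prod(normalized_shape))
y_2d = layer_norm_2d(x_2d, weight.reshape(-1), bias.reshape(-1))
y = y_2d.reshape(x.size())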
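One consequence of that reshape recipe for the stacked-tensor question: the mean and variance are computed over all of `normalized_shape`, so a LayerNorm with `normalized_shape=(N, dim)` shares a single set of statistics across the whole stacked tensor and is not equivalent to N independent LayerNorms. The NumPy sketch below is a hypothetical reference, not TE code. It checks this and shows one way to get N independent normalizations with per-row weights in a single call: normalize over the last dimension only, then apply `(N, dim)` weight and bias by broadcasting.

```python
import numpy as np
from math import prod

def layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
    # Flatten the normalized dims into one row, normalize each row,
    # then apply the elementwise affine transform
    x_2d = x.reshape(-1, prod(normalized_shape))
    mean = x_2d.mean(axis=-1, keepdims=True)
    var = x_2d.var(axis=-1, keepdims=True)
    y_2d = (x_2d - mean) / np.sqrt(var + eps)
    y_2d = y_2d * weight.reshape(-1) + bias.reshape(-1)
    return y_2d.reshape(x.shape)

N, dim = 3, 4
rng = np.random.default_rng(0)
# Give each row a different offset so per-row and joint statistics differ
x = rng.standard_normal((N, dim)) + np.arange(N)[:, None] * 10.0
weight = rng.standard_normal((N, dim))
bias = rng.standard_normal((N, dim))

# N separate LayerNorms, one per row, each with its own weight and bias
separate = np.stack([
    layer_norm(x[i], (dim,), weight[i], bias[i]) for i in range(N)
])

# One LayerNorm over the whole (N, dim) block: statistics are shared
joint = layer_norm(x, (N, dim), weight, bias)

# Normalize each row over `dim`, then broadcast the per-row weight and bias
batched = layer_norm(x, (dim,), np.ones(dim), np.zeros(dim)) * weight + bias

print(np.allclose(joint, separate), np.allclose(batched, separate))  # False True
```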

timmoon10 • Feb 28 '25