
how to use FusedRMSNorm?

Open EthanChen1234 opened this issue 1 year ago • 1 comments

Hi, TE is really great work.

How do I use FusedRMSNorm in TE?

https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm.py#L329

EthanChen1234 avatar Jun 19 '24 12:06 EthanChen1234

To use RMSNorm by itself, you can simply construct a te.RMSNorm module:

import torch
import transformer_engine.pytorch as te

# TE module
layer = te.RMSNorm(128)

# Synthetic data
x = torch.randn(128, 128, device="cuda")

# Forward and backward pass
y = layer(x)
y.sum().backward()

If you know the normalization will be followed by a linear layer, it may be worthwhile to use the te.LayerNormLinear or te.LayerNormMLP modules instead:

layer = te.LayerNormLinear(128, 128, normalization="RMSNorm")
y = layer(x)

This allows for some kernel fusions when running with FP8, e.g. fusing the RMSNorm with an FP8 cast.
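
For reference, here is a rough plain-Python sketch of the arithmetic an RMSNorm layer performs on each row of the input, in case it helps when checking outputs against another implementation such as the apex one. It is only an illustration of the math, not how TE's fused kernels are written. The helper name rms_norm_reference is made up for this example, and the eps default of 1e-5 is an assumption, so check it against the eps of the module you are comparing with.

```python
import math

def rms_norm_reference(row, weight, eps=1e-5):
    """Hypothetical helper: scale one row by 1/RMS(row), then by a per-element weight.

    eps=1e-5 is an assumed default; match it to the layer you compare against.
    """
    # Mean of the squared elements over the normalized (last) dimension
    mean_sq = sum(v * v for v in row) / len(row)
    # Reciprocal root mean square, with eps for numerical stability
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Elementwise scale by the learnable weight (gamma)
    return [v * inv_rms * w for v, w in zip(row, weight)]

row = [3.0, 4.0]
weight = [1.0, 1.0]  # an all-ones weight leaves only the normalization
out = rms_norm_reference(row, weight, eps=0.0)
```

With eps set to zero and an all-ones weight, the output row has a mean square of one, which is a quick sanity check for any RMSNorm implementation.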

timmoon10 avatar Jun 25 '24 22:06 timmoon10