TransformerEngine
RMSNorm precision different from HF implementation
We noticed there is a small implementation difference that makes transformer_engine.pytorch.module.rmsnorm (and also TELayerNormColumnParallelLinear) generate results that differ from the HF implementation.
The difference is the point at which hidden_states is converted back to bfloat16. Here's the gap:
- the red line is the native HF implementation, which converts hidden_states to bfloat16 before multiplying by the weight; TENorm's result differs from this implementation
- the green line matches TENorm's implementation, which converts hidden_states to bfloat16 after multiplying by the weight (see the sketch below)
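To make the two orderings concrete, here is a minimal pure-PyTorch sketch of the two variants described above (the function names are ours, not from either library):

```python
import torch

def rmsnorm_cast_before_weight(hidden_states, weight, eps=1e-5):
    # "Red line" / HF-style: normalize in float32, cast back to bfloat16,
    # then multiply by the (bfloat16) weight.
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return weight * x.to(hidden_states.dtype)

def rmsnorm_cast_after_weight(hidden_states, weight, eps=1e-5):
    # "Green line" / TENorm-style: keep the weight multiplication in float32
    # and cast back to bfloat16 only at the very end.
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (weight.to(torch.float32) * x).to(hidden_states.dtype)

# The two orderings round differently in bfloat16:
x = torch.randn(1, 1024, 4096, dtype=torch.bfloat16)
w = torch.randn(4096, dtype=torch.bfloat16)
print((rmsnorm_cast_before_weight(x, w) - rmsnorm_cast_after_weight(x, w)).abs().max())
```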
We wonder if TE could provide an option to match the HF implementation, which converts hidden_states to bfloat16 before multiplying by the weights. Thanks.
How to reproduce
Version: transformer-engine 1.7.0+4e7caa1
Code to reproduce:
import unittest
import torch
import torch.nn as nn
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as TELayerNorm
from copy_from_hf import HFRMSNorm

class TestLayerNormComparison(unittest.TestCase):
    def setUp(self):
        self.hidden_size = 4096
        self.batch_size = 1
        self.seq_length = 1024
        self.eps = 1e-5
        self.shared_weight = nn.Parameter(torch.randn(self.hidden_size, dtype=torch.bfloat16))
        self.te_layernorm = TELayerNorm(self.hidden_size, eps=self.eps, zero_centered_gamma=False).to(torch.bfloat16)
        self.hf_rmsnorm = HFRMSNorm(self.hidden_size, eps=self.eps).to(torch.bfloat16)
        self.te_layernorm.weight = self.shared_weight
        self.hf_rmsnorm.weight = self.shared_weight
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.te_layernorm.to(self.device)
        self.hf_rmsnorm.to(self.device)

    def test_layernorm_comparison(self):
        input_tensor = torch.randn(self.batch_size, self.seq_length, self.hidden_size,
                                   dtype=torch.bfloat16, device=self.device)
        with torch.no_grad():
            te_output = self.te_layernorm(input_tensor)
            hf_output = self.hf_rmsnorm(input_tensor)
            assert torch.allclose(te_output, hf_output, atol=1e-2)

if __name__ == '__main__':
    unittest.main()
First, define HFRMSNorm (in copy_from_hf.py) with the native HF implementation:
import torch
from torch import nn

class HFRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6, config=None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
The assertion should fail when we run the code with this implementation.
Now change the last line from `return self.weight * hidden_states.to(input_dtype)` to `return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)`, and the assertion should pass.
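Spelled out, the modified forward (only the return line changes relative to the HFRMSNorm above) is:

```python
def forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    # Multiply by the weight while still in float32, then cast back (TE behaviour).
    return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)
```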
That is correct: both RMSNorm and LayerNorm in TE perform all internal computation in FP32, so e.g. TE LayerNorm is equivalent to:
layer_norm = nn.LayerNorm(hidden_size)  # weights and bias kept in float32
x = x.to(torch.float32)
y = layer_norm(x)
y = y.to(torch.bfloat16)
The reason for that is to preserve the precision of the computation, especially since RMSNorm/LayerNorm weights are typically close to 1.
This is especially important and visible with the zero_centered_gamma option, which initializes the weight to 0 and adds 1 to it inside the normalization operation itself. Since floating point numbers are most precise around 0, and because bfloat16 does not have many mantissa bits, adding 1 to the weight in a precision other than float32 loses most of that precision - see e.g. this example:
>>> import torch
>>> a = torch.Tensor([0.003]).to(torch.bfloat16)
>>> a
tensor([0.0030], dtype=torch.bfloat16)
>>> a + 1
tensor([1.], dtype=torch.bfloat16)
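For contrast, continuing the same session, performing the addition in float32 (which is what TE's FP32-internal computation amounts to) keeps the small value:

```python
>>> a.to(torch.float32) + 1
tensor([1.0030])
```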
Based on this, I would argue that it is actually the HF implementation that is wrong here.
@ptrendx Thanks for your reply. I totally agree that we should use float32 to do all the calculations, in theory.
However, we're not training from scratch. We're doing continued training of open-source models like Llama3 and Qwen2 with Megatron-LM, and if we compare the logits generated by Megatron-LM and HF transformers, the RMSNorm implementation difference causes the logits to be very different (80% of the hidden_states elements differ by more than 0.01).
That's why I believe TE should at least provide an option to align the RMSNorm with HF transformers.
Yeah, I figured that's a probable reason for this ask. Could you open an issue in the HF transformers repo as well, then? It would be interesting to hear their opinion on the topic, and it would also raise their awareness so that, hopefully, the implementations are aligned to the right precisions for new models going forward.
I need to think about how to expose that option. In the meantime, if you want to change the TE implementation yourself to do the multiplication in the lower precision, you would need to change
https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh#L108-L109 and https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh#L248-L249 to use Ktraits::weight_t rather than the compute_t type.
Great, thank you @ptrendx ! I'll try to change the code myself.
Besides, here's the issue on HF: https://github.com/huggingface/transformers/issues/33133
We just stumbled upon this issue and compared the implementation of RMSNorm between TransformerEngine and TensorRT-LLM. It looks like TensorRT-LLM does the weight multiplication in the lower precision, consistent with the HF transformers implementation. This likely means that a model trained with TransformerEngine will produce (at least slightly) different outputs when run for inference with TensorRT-LLM.
I agree with @ptrendx that performing the operation in higher precision sounds sensible, but I think it would be useful to have the option to align implementations across NVIDIA's stack.
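Purely as an illustration of what such an option could mean at the Python reference level (the cast_before_weight flag below is hypothetical and not an existing TE, HF, or TensorRT-LLM API):

```python
import torch

def rmsnorm_reference(hidden_states, weight, eps=1e-5, cast_before_weight=True):
    # Hypothetical switch: True mimics HF transformers / TensorRT-LLM
    # (weight multiplication in the input dtype), False mimics TE
    # (weight multiplication kept in float32).
    input_dtype = hidden_states.dtype
    x = hidden_states.to(torch.float32)
    variance = x.pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    if cast_before_weight:
        return weight * x.to(input_dtype)
    return (weight.to(torch.float32) * x).to(input_dtype)
```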