
RMSNorm precision different from HF implementation

Open void-main opened this issue 1 year ago • 5 comments

We noticed there's a tiny implementation difference that makes transformer_engine.pytorch.module.rmsnorm (and also TELayerNormColumnParallelLinear) generate results different from the HF implementation.

And the tiny difference is when the hidden_states are converted back to bfloat16. Here's the gap (see the attached screenshot):

  • the red line is the native HF implementation, which converts hidden_states to bfloat16 before multiplying by the weight; TENorm's result differs from this implementation
  • the green line matches TENorm's implementation, which converts hidden_states to bfloat16 after multiplying by the weight

We wonder if TE could provide an option to match the HF implementation, which converts hidden_states to bfloat16 before multiplying by the weights. Thanks.
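
To make the difference concrete, here is a minimal sketch of the two orderings (standalone illustrative functions, not TE or HF code):

import torch

def rmsnorm_weight_in_bf16(x, weight, eps=1e-5):
    # HF-style: normalize in float32, cast back to bfloat16, then multiply by the bf16 weight
    x_fp32 = x.to(torch.float32)
    x_norm = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
    return weight * x_norm.to(x.dtype)

def rmsnorm_weight_in_fp32(x, weight, eps=1e-5):
    # TE-style: keep the weight multiply in float32 as well, cast only once at the end
    x_fp32 = x.to(torch.float32)
    x_norm = x_fp32 * torch.rsqrt(x_fp32.pow(2).mean(-1, keepdim=True) + eps)
    return (weight.to(torch.float32) * x_norm).to(x.dtype)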

How to reproduce

Version: transformer-engine 1.7.0+4e7caa1

Code to reproduce:

import unittest
import torch
import torch.nn as nn
from transformer_engine.pytorch.module.rmsnorm import RMSNorm as TELayerNorm
from copy_from_hf import HFRMSNorm

class TestLayerNormComparison(unittest.TestCase):
    def setUp(self):
        self.hidden_size = 4096
        self.batch_size = 1
        self.seq_length = 1024
        self.eps = 1e-5

        self.shared_weight = nn.Parameter(torch.randn(self.hidden_size, dtype=torch.bfloat16))

        self.te_layernorm = TELayerNorm(self.hidden_size, eps=self.eps, zero_centered_gamma=False).to(torch.bfloat16)
        self.hf_rmsnorm = HFRMSNorm(self.hidden_size, eps=self.eps).to(torch.bfloat16)

        self.te_layernorm.weight = self.shared_weight
        self.hf_rmsnorm.weight = self.shared_weight

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.te_layernorm.to(self.device)
        self.hf_rmsnorm.to(self.device)

    def test_layernorm_comparison(self):
        input_tensor = torch.randn(self.batch_size, self.seq_length, self.hidden_size,
                                   dtype=torch.bfloat16, device=self.device)

        with torch.no_grad():
            te_output = self.te_layernorm(input_tensor)
            hf_output = self.hf_rmsnorm(input_tensor)

        assert torch.allclose(te_output, hf_output, atol=1e-2)

if __name__ == '__main__':
    unittest.main()

First, define HFRMSNorm with the native implementation:

import torch
from torch import nn

class HFRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6, config=None):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

The assertion should fail when we run the code with this implementation.

Now, let's change the last line from `return self.weight * hidden_states.to(input_dtype)` to `return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)`; the assertion should then pass.
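
For clarity, here is the full modified forward (only the final line differs from the native implementation above):

def forward(self, hidden_states):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
    # multiply by the weight in float32 as well, casting back only at the very end
    return (self.weight.to(torch.float32) * hidden_states).to(input_dtype)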

void-main avatar Aug 23 '24 14:08 void-main

That is correct, both RMSNorm and LayerNorm in TE perform all internal computation in FP32, so e.g. TE LayerNorm is equivalent to:

x = x.to(torch.float32)
# normalization and the weight/bias multiply all happen in float32
y = torch.nn.functional.layer_norm(x, x.shape[-1:], weight.to(torch.float32), bias.to(torch.float32))
y = y.to(torch.bfloat16)

The reason for that is to preserve the precision of the computation, especially since RMSNorm/LayerNorm weights are typically close to 1. This is especially important and visible with the zero_centered_gamma option, which initializes the weight to 0 and adds 1 to it inside the normalization operation itself. Floating point numbers are most precise around 0, and because bfloat16 does not have many mantissa bits, adding the 1 in any precision other than float32 loses most of that precision - see e.g. this example:

>>> import torch
>>> a = torch.Tensor([0.003]).to(torch.bfloat16)
>>> a
tensor([0.0030], dtype=torch.bfloat16)
>>> a + 1
tensor([1.], dtype=torch.bfloat16)
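
For contrast, doing the same addition after casting to float32 keeps the small value (illustrative continuation of the session above):

>>> a.to(torch.float32) + 1
tensor([1.0030])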

Based on this, I would argue that it is actually the HF implementation that is wrong here.

ptrendx avatar Aug 23 '24 16:08 ptrendx

@ptrendx Thanks for your reply. I totally agree that, in theory, we should use float32 for all the calculations.

However, we're not training from scratch. We're doing continued training of open-source models like Llama3 and Qwen2 with Megatron-LM, and when we compare the logits generated by Megatron-LM and HF transformers, the RMSNorm implementation difference causes the logits to be very different (80% of the hidden_states elements differ by more than 0.01).
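
A quick way to quantify such a mismatch, as a sketch reusing the te_output and hf_output tensors from the reproduction script above (the 0.01 threshold matches the figure quoted here):

# fraction of elements whose TE and HF outputs differ by more than 0.01
diff = (te_output.float() - hf_output.float()).abs()
mismatch_fraction = (diff > 0.01).float().mean().item()
print(f"{mismatch_fraction:.1%} of elements differ by more than 0.01")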

That's why I believe we should at least provide an option to align the RMSNorm with HF transformers.

void-main avatar Aug 26 '24 05:08 void-main

Yeah, I figured that was the probable reason for this ask. Could you open an issue in the HF transformers repo as well, then? It would be interesting to hear their opinion on the topic and also to raise their awareness so that, hopefully, the implementations get aligned to the right precisions for new models going forward.

I need to think about how to expose that option. In the meantime, if you wanted to change the TE implementation yourself to do the multiplication in the lower precision, you would need to change https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh#L108-L109 and https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/common/rmsnorm/rmsnorm_fwd_kernels.cuh#L248-L249 to use Ktraits::weight_t rather than the compute_t type.

ptrendx avatar Aug 26 '24 18:08 ptrendx

Great, thank you @ptrendx ! I'll try to change the code myself.

Besides, here's the issue on HF: https://github.com/huggingface/transformers/issues/33133

void-main avatar Aug 27 '24 01:08 void-main

We just stumbled upon this issue and compared the RMSNorm implementations in TransformerEngine and TensorRT-LLM. It looks like TensorRT-LLM does the weight multiplication in the lower precision, consistent with the HF transformers implementation. This likely means that a model trained with TransformerEngine will produce (at least slightly) different outputs when run for inference with TensorRT-LLM.

I agree with @ptrendx that performing the operation in higher precision sounds sensible, but I think it would be useful to have the option to align implementations across NVIDIA's stack.

fjosw avatar Aug 29 '24 08:08 fjosw