
Questions about numeric precision of FusedRMSNorm

Open yingtongxiong opened this issue 2 years ago • 5 comments

Hello, I have tested the numeric precision of FusedRMSNorm and MixedFusedRMSNorm against reference RMSNorm implementations. I found that the gradients of the model weights do not match: there is quite a large difference in the weight gradients. Could you help me resolve this or give some suggestions?

The following is my test implementation:

import copy

import torch
from torch import nn

from apex.normalization.fused_layer_norm import MixedFusedRMSNorm, FusedRMSNorm


def manual_rms_norm(input, normalized_shape, weight, eps):
    # layer norm should always be calculated in float32
    
    dims = tuple(i for i in range(-1, -len(normalized_shape)-1, -1))
    
    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
    input = input * torch.rsqrt(variance + eps)
    if weight is None:
        return input
    # convert into half-precision if necessary
    if weight.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(weight.dtype)
    return weight * input


class Manual_RMSNorm(nn.Module):
    
    def __init__(self, dim, eps, device=None):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(dim, device=device))
        
    def forward(self, x):
        return manual_rms_norm(x, self.weight.shape, self.weight, self.eps)

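# NOTE: main() below constructs a reference model with an RMSNorm class that was not
# defined in the original snippet. The definition here is a minimal LLaMA-style RMSNorm
# (assumed from the LLaMA code linked later in this thread) so that the script runs:
class RMSNorm(nn.Module):

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # RMS statistic over the last dimension
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # normalize in fp32, cast back to the input dtype, then apply the weight
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
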

def main():
    
    dtype = torch.float16
    device = 'cuda'
    hidden_size = 10240
    input_size = 10240
    weight_type = torch.float16

    repeats = 1000
    
    x_pt = torch.randn(1, input_size, hidden_size, dtype=dtype, device=device).requires_grad_()
    x = x_pt.detach().clone().requires_grad_()
    x_mixed = x_pt.detach().clone().requires_grad_()
    x_func = x_pt.detach().clone().requires_grad_()

    model_pt = RMSNorm(dim=hidden_size, eps=1e-6).to(device=device, dtype=dtype)
    model = FusedRMSNorm(hidden_size, eps=1e-6).to(device=device, dtype=dtype)
    model_mixed = MixedFusedRMSNorm(hidden_size, eps=1e-6).to(device=device, dtype=dtype) 
    model_manual = Manual_RMSNorm(hidden_size, eps=1e-6).to(device=device, dtype=dtype)
    
    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_mixed.weight.copy_(model_pt.weight)
        model_manual.weight.copy_(model_pt.weight)
    
    output_pt = model_pt(x_pt)
    output = model(x)
    output_mixed = model_mixed(x_mixed)
    output_man = model_manual(x_func)
    
    loss = torch.rand_like(output) / 32
    
    output_pt.backward(loss)
    output.backward(loss)
    output_mixed.backward(loss)
    output_man.backward(loss)
    
    
    print("pytorch = ", model_pt.weight.grad)
    print("fused = ", model.weight.grad)
    print("mixed = ", model_mixed.weight.grad)
    print("man = ", model_manual.weight.grad)

The following are my test results:

pytorch =  tensor([-0.1270, -0.2615, -3.3340,  ..., -0.0049, -2.6680, -3.2207],
       device='cuda:0', dtype=torch.float16)
fused =  tensor([-0.1283, -0.2607, -3.3340,  ..., -0.0045, -2.6680, -3.2207],
       device='cuda:0', dtype=torch.float16)
mixed =  tensor([-0.1283, -0.2607, -3.3340,  ..., -0.0045, -2.6680, -3.2207],
       device='cuda:0', dtype=torch.float16)
man =  tensor([-0.1270, -0.2615, -3.3340,  ..., -0.0049, -2.6680, -3.2207],
       device='cuda:0', dtype=torch.float16)
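
Eyeballing the printed tensors only shows a few elements, so it may also help to report the maximum and mean absolute differences of each weight gradient against the PyTorch reference. A minimal sketch, assuming the models from the script above are still in scope:

# compare each implementation's weight gradient against the RMSNorm reference in fp32
ref = model_pt.weight.grad.float()
for name, m in [("fused", model), ("mixed", model_mixed), ("manual", model_manual)]:
    diff = (m.weight.grad.float() - ref).abs()
    print(f"{name}: max abs diff = {diff.max().item():.3e}, mean abs diff = {diff.mean().item():.3e}")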

yingtongxiong avatar Apr 28 '23 08:04 yingtongxiong

The above Manual_RMSNorm seems similar to LLaMa's RMSNorm from https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L43C16-L43C16

This does: cast to FP32 -> normalize with the variance -> cast back to BF16 -> multiply by the layernorm weights. However, if I change the order to: cast to FP32 -> normalize with the variance -> multiply by the layernorm weights -> cast back to BF16, the calculation is more precise, and the values exactly match apex's FusedLayerNorm.

See the code snippet below -

    # skip the early downcast to half precision:
    # if weight.dtype in [torch.float16, torch.bfloat16]:
    #     input = input.to(weight.dtype)
    # multiply while still in fp32, then cast the result back to the original dtype
    return (weight * input).type_as(input)

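A minimal sketch of the two orderings, for clarity (the function names are illustrative, not from apex):

import torch

def rms_norm_cast_then_weight(x, weight, eps=1e-6):
    # LLaMA-style ordering: normalize in fp32, cast back to half precision, then multiply by the weight
    variance = x.float().pow(2).mean(-1, keepdim=True)
    y = (x.float() * torch.rsqrt(variance + eps)).type_as(x)  # downcast happens here
    return weight * y  # product computed in half precision

def rms_norm_weight_then_cast(x, weight, eps=1e-6):
    # ordering proposed above: normalize in fp32, multiply by the weight, then cast back
    variance = x.float().pow(2).mean(-1, keepdim=True)
    y = x.float() * torch.rsqrt(variance + eps)
    return (weight * y).type_as(x)  # product computed in fp32, downcast happens last
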
akhilkedia avatar Jun 29 '23 12:06 akhilkedia

I therefore recommend that this issue can be safely closed.

akhilkedia avatar Jun 29 '23 12:06 akhilkedia

I encountered a similar issue where the results of MixedFusedRMSNorm and LLAMA's RMSNorm are inconsistent when applied to the same tensor. I have made the change suggested above: return (weight * input).type_as(input)

SefaZeng avatar Jan 17 '24 02:01 SefaZeng

The MixedFusedRMSNorm outputs:

tensor([ 3.2715e-02, -1.8921e-03,  9.3460e-05,  5.7068e-03, -8.1177e-03, ...

The LLAMA RMSNorm outputs:

tensor([[ 3.2745e-02, -1.8914e-03,  9.3334e-05,  5.6968e-03, -8.1640e-03, ...

SefaZeng avatar Jan 17 '24 02:01 SefaZeng