Questions about numeric precision of FusedRMSNorm
Hello, I have tested the numeric precision of FusedRMSNorm and MixedFusedRMSNorm against two reference RMSNorm implementations. I found that the gradients of the model weights do not match: there is a quite large difference in the weight gradients. Could you help me resolve this or offer some suggestions?
The following is my test implementation:
import copy

import torch
from torch import nn

from apex.normalization.fused_layer_norm import MixedFusedRMSNorm, FusedRMSNorm


def manual_rms_norm(input, normalized_shape, weight, eps):
    # layer norm should always be calculated in float32
    dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
    input = input * torch.rsqrt(variance + eps)

    if weight is None:
        return input

    # convert into half-precision if necessary
    if weight.dtype in [torch.float16, torch.bfloat16]:
        input = input.to(weight.dtype)

    return weight * input


class Manual_RMSNorm(nn.Module):
    def __init__(self, dim, eps, device=None):
        super().__init__()
        self.eps = eps
        self.dim = dim
        self.weight = nn.Parameter(torch.ones(dim, device=device))

    def forward(self, x):
        return manual_rms_norm(x, self.weight.shape, self.weight, self.eps)


class RMSNorm(nn.Module):
    # RMSNorm was not defined in the original snippet; presumably it is LLaMA's
    # RMSNorm (see the link further down), reproduced here so the script runs.
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight


def main():
    dtype = torch.float16
    device = 'cuda'
    hidden_size = 10240
    input_size = 10240
    weight_type = torch.float16
    repeats = 1000

    x_pt = torch.randn(1, input_size, hidden_size, dtype=dtype, device=device).requires_grad_()
    x = x_pt.detach().clone().requires_grad_()
    x_mixed = x_pt.detach().clone().requires_grad_()
    x_func = x_pt.detach().clone().requires_grad_()

    model_pt = RMSNorm(dim=hidden_size, eps=1e-6).to(device=device, dtype=dtype)
    model = FusedRMSNorm(hidden_size, eps=1e-6).to(device=device, dtype=dtype)
    model_mixed = MixedFusedRMSNorm(hidden_size, eps=1e-6).to(device=device, dtype=dtype)
    model_manual = Manual_RMSNorm(hidden_size, eps=1e-6).to(device=device, dtype=dtype)

    with torch.no_grad():
        model.weight.copy_(model_pt.weight)
        model_mixed.weight.copy_(model_pt.weight)
        model_manual.weight.copy_(model_pt.weight)

    output_pt = model_pt(x_pt)
    output = model(x)
    output_mixed = model_mixed(x_mixed)
    output_man = model_manual(x_func)

    loss = torch.rand_like(output) / 32
    output_pt.backward(loss)
    output.backward(loss)
    output_mixed.backward(loss)
    output_man.backward(loss)

    print("pytorch = ", model_pt.weight.grad)
    print("fused = ", model.weight.grad)
    print("mixed = ", model_mixed.weight.grad)
    print("man = ", model_manual.weight.grad)


if __name__ == "__main__":
    main()
The following are my test results:
pytorch = tensor([-0.1270, -0.2615, -3.3340, ..., -0.0049, -2.6680, -3.2207],
device='cuda:0', dtype=torch.float16)
fused = tensor([-0.1283, -0.2607, -3.3340, ..., -0.0045, -2.6680, -3.2207],
device='cuda:0', dtype=torch.float16)
mixed = tensor([-0.1283, -0.2607, -3.3340, ..., -0.0045, -2.6680, -3.2207],
device='cuda:0', dtype=torch.float16)
man = tensor([-0.1270, -0.2615, -3.3340, ..., -0.0049, -2.6680, -3.2207],
device='cuda:0', dtype=torch.float16)
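One way to judge which of these half-precision gradients is actually closer to the true value (this is not part of the original report) is to rerun the same RMSNorm backward pass in float64 and measure the distance of each float16 gradient from that reference. A minimal sketch, assuming the test script above; compare_to_fp64_reference is a hypothetical helper, not an apex API:

    # Hypothetical helper: build a float64 reference weight gradient for the same
    # input, weight, and upstream gradient, then compare each fp16 gradient to it.
    def compare_to_fp64_reference(x_fp16, weight_fp16, upstream_fp16, eps=1e-6):
        x64 = x_fp16.detach().double().requires_grad_()
        w64 = torch.nn.Parameter(weight_fp16.detach().double())
        variance = x64.pow(2).mean(-1, keepdim=True)
        out64 = w64 * (x64 * torch.rsqrt(variance + eps))
        out64.backward(upstream_fp16.double())
        return w64.grad  # float64 reference for the weight gradient

    # usage (inside main(), after the backward passes):
    # ref = compare_to_fp64_reference(x_pt, model_pt.weight, loss)
    # for name, grad in [("pytorch", model_pt.weight.grad), ("fused", model.weight.grad)]:
    #     print(name, (grad.double() - ref).abs().max().item())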
The above Manual_RMSNorm seems similar to LLaMA's RMSNorm from https://github.com/facebookresearch/llama/blob/57b0eb62de0636e75af471e49e2f1862d908d9d8/llama/model.py#L43C16-L43C16
This does:
cast to FP32 -> normalize with the variance -> cast back to BF16 -> multiply by the layer-norm weights
However, if I change the order to
cast to FP32 -> normalize with the variance -> multiply by the layer-norm weights -> cast back to BF16
the calculation is more precise, and the values exactly match apex's FusedRMSNorm.
See the code snippet below -
    # if weight.dtype in [torch.float16, torch.bfloat16]:
    #     input = input.to(weight.dtype)
    return (weight * input).type_as(input)
I think this issue can hence be safely closed.
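For clarity, here is a minimal standalone sketch (not from the original posts; tensor shapes and names are only illustrative) of the two orderings being discussed, showing that casting back to half precision before or after the weight multiplication gives slightly different results:

    import torch

    x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
    weight = torch.ones(1024, dtype=torch.float16, device="cuda")
    eps = 1e-6

    x32 = x.float()
    inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)

    # Ordering 1 (LLaMA-style): normalize in FP32, cast back to FP16, then scale.
    out_cast_first = weight * (x32 * inv_rms).to(x.dtype)

    # Ordering 2 (suggested above): normalize and scale in FP32, cast back at the end.
    out_cast_last = (weight.float() * (x32 * inv_rms)).to(x.dtype)

    print((out_cast_first - out_cast_last).abs().max())  # nonzero in general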
I encountered a similar issue where the results of MixedFusedRMSNorm and LLaMA's RMSNorm are inconsistent when applied to the same tensor. I made the suggested change:
    return (weight * input).type_as(input)
The MixedFusedRMSNorm outputs:
tensor([ 3.2715e-02, -1.8921e-03, 9.3460e-05, 5.7068e-03, -8.1177e-03, ...
The LLaMA RMSNorm outputs:
tensor([[ 3.2745e-02, -1.8914e-03, 9.3334e-05, 5.6968e-03, -8.1640e-03, ...
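If it helps to quantify such mismatches rather than eyeballing the printed tensors, a small hedged sketch (report_mismatch is an illustrative helper, not an apex utility; pass it the two outputs from the test script above) is:

    import torch

    def report_mismatch(a, b, atol=1e-3, rtol=1e-3):
        # Compare two RMSNorm outputs elementwise in float32.
        a32, b32 = a.float(), b.float()
        diff = (a32 - b32).abs()
        print("max abs diff :", diff.max().item())
        print("mean abs diff:", diff.mean().item())
        print("allclose     :", torch.allclose(a32, b32, atol=atol, rtol=rtol))

    # e.g. report_mismatch(output_mixed, output_pt)  # tensors from the test script above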