TransformerEngine
Linear in Float32 is inaccurate
I'm seeing some pretty large divergences in my model when I port it to Transformer Engine. While tracking these down (checking whether I'm crazy or TE is crazy), I'm pretty sure the float32 implementation of Linear isn't working right:
from transformer_engine.pytorch import Linear as TELinear
from torch.nn import Linear as TorchLinear
import torch

for dtype in [torch.float16, torch.bfloat16, torch.float32]:
    torch_linear = TorchLinear(2, 2, dtype=dtype, device='cuda', bias=False)
    te_linear = TELinear(2, 2, params_dtype=dtype, device='cuda', bias=False)
    # Copy the PyTorch weights into the TE layer so both run the same GEMM.
    state_dict = torch_linear.state_dict()
    state_dict['_extra_state'] = None
    te_linear.load_state_dict(state_dict)
    input = torch.randn(2, 2, dtype=dtype, device='cuda')
    print(torch_linear(input) - te_linear(input))
"""
prints
tensor([[0., 0.],
        [0., 0.]], device='cuda:0', dtype=torch.float16,
       grad_fn=<SubBackward0>)
tensor([[0., 0.],
        [0., 0.]], device='cuda:0', dtype=torch.bfloat16,
       grad_fn=<SubBackward0>)
tensor([[-3.9619e-04,  1.4237e-04],
        [ 3.4451e-05,  1.4246e-05]], device='cuda:0', grad_fn=<SubBackward0>)
"""
Reproduces on both main and stable.
Hi @c0g. TE uses TF32 by default when performing FP32 Linears, while PyTorch by default does not. That is the reason for the difference you are seeing in this unit test; you can confirm it by setting the environment variable NVIDIA_TF32_OVERRIDE to 0 to disable TF32 entirely, after which the outputs should match. That said, I would be surprised if this were the cause of the divergence in your model (still, try that env variable and see whether you get better results). Could you share some more information about it so that we can better assist you in solving it?
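For reference, a minimal sketch (not from the original thread) of the two ways to line the defaults up. It assumes NVIDIA_TF32_OVERRIDE is picked up when cuBLAS initializes, so setting it in the launching shell is safer than setting it inside the script:

import os
import torch

# Option 1: disable TF32 everywhere so TE also runs true FP32 GEMMs.
# Assumption: this must happen before the CUDA libraries are initialized,
# so prefer exporting it in the shell that launches the script.
# os.environ["NVIDIA_TF32_OVERRIDE"] = "0"

# Option 2: let PyTorch use TF32 for FP32 matmuls too, matching TE's default.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

With either option, the repro above should print all-zero differences for the float32 case as well.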
I can't currently share the model arch (I hate myself for saying that), but it's a fairly basic transformer. To be clear, I didn't mean that training diverges; rather, loading my weights into the TE version gives different answers, and I'm trying to track down the sources of error. Thanks for this - I have TF32 turned on in my main training runs but always forget about it. I got caught by this before as well; can't believe I forgot it!
One thing I have noticed is that the numerics checks in TE have quite broad bounds; e.g. 5e-2/5e-3 feels high for a linear: https://github.com/NVIDIA/TransformerEngine/blob/30cad990d09fce3c37951d09c6ec085c1216a313/tests/pytorch/test_numerics.py#L876.
Well, considering that all of the computation here uses types with a smaller mantissa (TF32 in the case of FP32, FP16, and especially BF16), this tolerance is actually not so surprising. Keep in mind that the error increases with the number of elements in the sum, and even in your 2x2 matrix multiplication test the differences reached 4e-4.
That said, I agree we should be careful about tolerances not being tight enough to catch actual errors.
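To illustrate the point about error growing with the reduction size, here is a small sketch (not from the original thread) that compares a plain FP32 matmul against the same matmul with TF32 enabled, for a few inner dimensions:

import torch

torch.manual_seed(0)

for k in (2, 256, 4096):
    a = torch.randn(128, k, device='cuda', dtype=torch.float32)
    b = torch.randn(k, 128, device='cuda', dtype=torch.float32)

    # Reference: full FP32 GEMM (TF32 disabled).
    torch.backends.cuda.matmul.allow_tf32 = False
    ref = a @ b

    # Same GEMM with TF32 inputs (10-bit mantissa) and FP32 accumulation.
    torch.backends.cuda.matmul.allow_tf32 = True
    approx = a @ b

    print(f"k={k:5d}  max abs diff: {(ref - approx).abs().max().item():.3e}")

The maximum absolute difference grows as k increases, which is why tolerances in the TE numerics tests scale with the size of the reduction rather than being a single tight constant.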