When comparing Thunder Torch Executor to Torch Eager, the ResNet18 gradients are not close for FP32.
🐛 Bug
To Reproduce
Steps to reproduce the behavior:
- Modify the test case at
https://github.com/Lightning-AI/lightning-thunder/blob/6320b2f0cad03dff49c5141b6731587451711a4d/thunder/tests/test_inplace_functionalization.py#L184
to
if train and executor == TorchExecutor:
- Run
pytest thunder/tests/test_inplace_functionalization.py -k test_parse_resnet18_torch_cuda_float32[True]
and see the error:
if train and executor == TorchExecutor: # and dtype == thunder.float64:
torch_grads = torch.autograd.grad(out1, ref_model.parameters(), torch.ones_like(out1))
thunder_grads = torch.autograd.grad(out2, jitted.parameters(), torch.ones_like(out2))
> torch.testing.assert_close(torch_grads, thunder_grads)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 9405 / 9408 (100.0%)
E Greatest absolute difference: 0.09205560386180878 at index (4, 1, 5, 0) (up to 1e-05 allowed)
E Greatest relative difference: 10.715060234069824 at index (39, 1, 3, 0) (up to 1.3e-06 allowed)
E
E The failure occurred for item [0]
thunder/tests/test_inplace_functionalization.py:187: AssertionError
=================================================== short test summary info ===================================================
FAILED thunder/tests/test_inplace_functionalization.py::test_parse_resnet18_torch_cuda_float32[True] - AssertionError: Tensor-likes are not close!
Run this script:
import os
# Disable TF32 and make cuBLAS deterministic; set before any CUDA work so the CUDA libraries pick them up
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import random
import torch
import torchvision

torch.manual_seed(42)
random.seed(42)
torch.use_deterministic_algorithms(True)

model = torchvision.models.resnet18(weights=None).to(device="cuda", dtype=torch.float32)
x = torch.randn((1, 3, 224, 224), dtype=torch.float32, device="cuda", requires_grad=True)
# gradcheck compares the finite-difference (numerical) Jacobian with the autograd (analytical) one
print(torch.autograd.gradcheck(model, (x,)))
It fails with a GradcheckError:
root@9340b8cf8485:/wayan/lightning-thunder# python thunder/tests/testtrace.py
/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py:920: UserWarning: Input #0 requires gradient and is not a double precision floating point or complex. This check will likely fail if all the inputs are not of double precision floating point or complex.
warnings.warn(
/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py:768: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:135.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
File "/wayan/lightning-thunder/thunder/tests/testtrace.py", line 15, in <module>
print(torch.autograd.gradcheck(model, (x,)))
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py", line 2053, in gradcheck
return _gradcheck_helper(**args)
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py", line 2082, in _gradcheck_helper
_gradcheck_real_imag(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py", line 1492, in _gradcheck_real_imag
gradcheck_fn(
File "/usr/local/lib/python3.10/dist-packages/torch/autograd/gradcheck.py", line 1633, in _slow_gradcheck
raise GradcheckError(
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[ 0.1043, 0.0522, 0.0298, ..., -0.0149, -0.0447, -0.0596],
[ 0.1043, -0.0447, 0.0298, ..., 0.1341, 0.0298, -0.0894],
[ 0.1192, -0.1043, 0.0000, ..., -0.0596, -0.0149, 0.0596],
...,
[ 0.1788, -0.0820, -0.0149, ..., 0.0224, 0.1341, -0.0596],
[ 0.0000, -0.2459, 0.1639, ..., 0.0894, -0.1267, -0.0596],
[ 0.1043, -0.0075, -0.1043, ..., 0.0894, -0.0969, 0.0000]],
device='cuda:0')
analytical:tensor([[ 2.5345e-04, -2.1945e-04, 7.5599e-05, ..., -1.5271e-04,
-1.6242e-04, 4.9330e-04],
[-2.0753e-04, 5.1979e-04, 8.1766e-05, ..., -2.5569e-04,
-2.4477e-04, 1.9414e-04],
[-1.2130e-04, 1.2330e-04, -2.3220e-04, ..., 2.7823e-04,
2.9276e-04, -1.9633e-04],
...,
[-7.2192e-05, -9.3861e-05, -4.2660e-05, ..., -7.6299e-05,
-6.6284e-05, 1.2527e-05],
[ 4.0978e-05, 2.3847e-05, 2.6876e-05, ..., -2.3141e-05,
-1.0444e-06, 1.5903e-05],
[-4.5947e-05, -1.3556e-05, -8.9267e-05, ..., 6.1379e-05,
2.5143e-05, 2.6964e-05]], device='cuda:0')
With float64, gradcheck passes.
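A minimal sketch of the float64 setup that gradcheck's defaults are designed for. To keep it fast and CPU-only, it uses a small hypothetical stand-in (conv + BatchNorm in train mode + Tanh) rather than ResNet18; Tanh is used instead of ReLU so that finite-difference steps do not straddle ReLU's kink at zero.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical small stand-in for ResNet18: conv + BatchNorm (train mode) + smooth activation
model = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.Tanh(),
).double().train()

# Both the parameters (via .double()) and the input are float64,
# which is what gradcheck's default eps/atol/rtol assume
x = torch.randn(2, 3, 5, 5, dtype=torch.float64, requires_grad=True)

# Raises GradcheckError on mismatch (raise_exception=True by default), returns True otherwise
ok = torch.autograd.gradcheck(model, (x,))
print(ok)
```

For ResNet18 itself the same pattern applies (model.double() and a float64 x), but slow-mode gradcheck perturbs every input element, so on a 1x3x224x224 input it is expensive; gradcheck's fast_mode=True option reduces that cost.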
The torch.autograd.gradcheck documentation (https://pytorch.org/docs/stable/generated/torch.autograd.gradcheck.gradcheck.html) says:
"Note: The default values are designed for input of double precision. This check will likely fail if input is of less precision, e.g., FloatTensor."
However, the values above seem very far off (the numerical entries are on the order of 1e-1, the analytical ones on the order of 1e-4), so I'm wondering whether one of the operators we call has a bug or an input assumption that is not satisfied.
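One way to sanity-check whether the fp32 numerical Jacobian is meaningful is to look at the resolution of fp32 central differences. gradcheck perturbs each input by eps=1e-6 (its default) and computes (f(x+eps) - f(x-eps)) / (2*eps); if the outputs are float32, their difference can only change in units of the float32 spacing. The sketch below assumes output entries of magnitude between 0.25 and 0.5, which is an assumption about the logits' scale, not something measured in this issue.

```python
import math
import struct

EPS = 1e-6  # gradcheck's default finite-difference perturbation


def to_fp32(v: float) -> float:
    """Round a Python float (fp64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", v))[0]


def fp32_spacing(v: float) -> float:
    """Gap to the next float32 above |v| (normal range only)."""
    return 2.0 ** (math.floor(math.log2(abs(v))) - 23)


# Smallest nonzero central difference when the output entry lies in [0.25, 0.5)
step = fp32_spacing(0.3) / (2 * EPS)
print(f"step = {step:.4f}")  # step = 0.0149

# A true derivative of 1e-4 (the scale of the analytical Jacobian) is invisible:
# both perturbed outputs round to the same float32 value
f_plus = to_fp32(0.3125 + 1e-4 * EPS)
f_minus = to_fp32(0.3125 - 1e-4 * EPS)
print((f_plus - f_minus) / (2 * EPS))  # 0.0

# Some entries of the numerical Jacobian printed above, divided by half a step
observed = [0.1043, 0.0522, 0.0298, -0.0149, -0.0447, -0.0596, 0.1192, -0.2459, 0.1639, -0.0075]
ratios = [v / (step / 2) for v in observed]
print([round(r, 2) for r in ratios])  # each is close to an integer
```

Up to display rounding, these entries are integer multiples of half that step, which is consistent with the fp32 finite differences being dominated by rounding rather than by the true derivative. This only speaks to the gradcheck failure, not to the Thunder vs. eager mismatch.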
Comparison of fp64 and fp32 gradients:
- Torch eager fp64 vs. Thunder torchex fp32: mismatched elements 9248 / 9408 (98.3%); greatest absolute difference 0.00013152781110292722 at index (33, 2, 5, 6) (up to 1e-07 allowed); greatest relative difference 0.04311691189015545 at index (25, 2, 4, 4) (up to 1e-07 allowed)
- Torch eager fp64 vs. torch eager fp32: mismatched elements 9242 / 9408 (98.2%); greatest absolute difference 9.72576008653192e-05 at index (44, 2, 4, 1) (up to 1e-07 allowed); greatest relative difference 0.049338367753292256 at index (25, 2, 4, 4) (up to 1e-07 allowed)
- Torch eager fp32 vs. Thunder torchex fp32: mismatched elements 5468 / 9408 (58.1%); greatest absolute difference 0.00011849403381347656 at index (2, 2, 5, 0) (up to 1e-05 allowed); greatest relative difference 0.027340635657310486 at index (41, 2, 6, 6) (up to 1.3e-06 allowed)
Script used for the comparison above:
import os
# Disable TF32 so fp32 matmuls and convolutions are not run in TF32
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"
# os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
import random
import torch
import torchvision
import thunder

torch.manual_seed(42)
random.seed(42)
# torch.use_deterministic_algorithms(True)

# Three ResNet18 copies in train mode; both fp32 models get the fp64 model's weights
model_fp32 = torchvision.models.resnet18(weights=None).to(device="cuda", dtype=torch.float32).train()
torch_model_fp32 = torchvision.models.resnet18(weights=None).to(device="cuda", dtype=torch.float32).train()
torch_model_fp64 = torchvision.models.resnet18(weights=None).to(device="cuda", dtype=torch.float64).train()
model_fp32.load_state_dict(torch_model_fp64.state_dict())
torch_model_fp32.load_state_dict(torch_model_fp64.state_dict())

x = torch.randn((1, 3, 224, 224), dtype=torch.float64, device="cuda")

# Thunder (torch executor only) in fp32 vs. PyTorch eager in fp32 and fp64
jitted_fp32 = thunder.jit(model_fp32, executors=[thunder.pytorch_executor])
out1 = jitted_fp32(x.to(torch.float32))
out2 = torch_model_fp32(x.to(torch.float32))
out3 = torch_model_fp64(x)
torch.testing.assert_close(out1, out2)  # compare forward outputs first

thunder_grads_fp32 = torch.autograd.grad(out1, jitted_fp32.parameters(), torch.ones_like(out1))
torch_grads_fp32 = torch.autograd.grad(out2, torch_model_fp32.parameters(), torch.ones_like(out2))
torch_grads_fp64 = torch.autograd.grad(out3, torch_model_fp64.parameters(), torch.ones_like(out3))

# Upcast the fp32 gradients so they can be compared against the fp64 reference
tmp1 = [g.to(torch.float64) for g in thunder_grads_fp32]
tmp2 = [g.to(torch.float64) for g in torch_grads_fp32]

# Torch eager fp32 vs. Thunder torchex fp32 (fails, see results above)
torch.testing.assert_close(thunder_grads_fp32, torch_grads_fp32)
# print("thunder:")
# torch.testing.assert_close(tmp1, torch_grads_fp64)
# print("torch:")
# torch.testing.assert_close(tmp2, torch_grads_fp64)
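The script stops at the first assert_close failure. To see which parameters' gradients drift furthest from the fp64 reference, a hypothetical helper like report_grad_errors below could rank them by name. The demo uses a tiny CPU model instead of ResNet18 so it runs anywhere.

```python
import torch
from torch import nn


def report_grad_errors(names, grads_fp32, grads_fp64):
    """Return (name, max abs error) pairs for fp32 grads vs. an fp64 reference, largest first."""
    rows = []
    for name, g32, g64 in zip(names, grads_fp32, grads_fp64):
        err = (g32.to(torch.float64) - g64).abs().max().item()
        rows.append((name, err))
    return sorted(rows, key=lambda row: row[1], reverse=True)


# Tiny CPU demo: an fp64 reference model and an fp32 copy with the same weights
torch.manual_seed(0)
ref = nn.Sequential(nn.Linear(8, 4), nn.Tanh(), nn.Linear(4, 2)).double()
low = nn.Sequential(nn.Linear(8, 4), nn.Tanh(), nn.Linear(4, 2)).float()
low.load_state_dict(ref.state_dict())  # values are cast to float32 on copy

x = torch.randn(3, 8, dtype=torch.float64)
out64, out32 = ref(x), low(x.float())
g64 = torch.autograd.grad(out64, ref.parameters(), torch.ones_like(out64))
g32 = torch.autograd.grad(out32, low.parameters(), torch.ones_like(out32))

names = [n for n, _ in ref.named_parameters()]
rows = report_grad_errors(names, g32, g64)
for name, err in rows:
    print(f"{name:10s} {err:.3e}")
```

In the comparison script it could be called as report_grad_errors([n for n, _ in torch_model_fp64.named_parameters()], thunder_grads_fp32, torch_grads_fp64), and again with torch_grads_fp32, assuming jitted_fp32.parameters() yields parameters in the same order as the eager model (which the script already relies on).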
Triage review:
- The problem seems to be in the backward pass.
- Could there be a TF32-related discrepancy in the forward and backward passes?
- But it also happens in fp64, so it is likely not TF32-specific, or even dtype-specific.
- Maybe try replacing modules with identity to narrow down which module is failing.
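A minimal sketch of the identity-override idea from the last triage bullet. The helper replace_with_identity is hypothetical (not part of Thunder or torchvision); it deep-copies a model and swaps the named submodules for nn.Identity. Only shape-preserving submodules can be swapped without breaking the forward pass.

```python
import copy

import torch
from torch import nn


def replace_with_identity(model: nn.Module, names: list[str]) -> nn.Module:
    """Return a deep copy of `model` with each named submodule swapped for nn.Identity()."""
    model = copy.deepcopy(model)
    for name in names:
        # "layer1.0.bn1" -> parent "layer1.0", child "bn1"; top-level names have no parent path
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, nn.Identity())
    return model


# Tiny CPU demo: drop the BatchNorm layer (index "1") of a small Sequential model
toy = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.BatchNorm2d(4), nn.ReLU())
pruned = replace_with_identity(toy, ["1"])
print(pruned)
out = pruned(torch.randn(2, 3, 8, 8))
print(out.shape)  # torch.Size([2, 4, 8, 8])
```

For ResNet18, candidate names would be e.g. "bn1", "layer1.0.bn1", or "layer1.0" (a stride-1 BasicBlock whose input and output shapes match). The same replacement would have to be applied to all three models after load_state_dict, since removing a module changes the state-dict keys, before rerunning the gradient comparison.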