torchani
torch.jit.script profile guided optimisations produce errors in aev_computer gradients
Hi, I have found that with PyTorch 1.13 and 2.0 (but not with PyTorch <= 1.12), the torch.jit.script profile-guided optimisations (which are on by default) cause significant errors in the position gradients computed via backpropagation through aev_computer when using a CUDA device. This is demonstrated in issue https://github.com/openmm/openmm-ml/issues/50.
An example is shown below; manually turning off the JIT optimisations gives accurate forces:
from torchani.neurochem import parse_neurochem_resources, Constants
from torchani.aev import AEVComputer
import torch
import numpy as np
class Model(torch.nn.Module):
    def __init__(self, device):
        super(Model, self).__init__()
        info_file_path = 'ani-2x_8x.info'
        const_file, _, _, _ = parse_neurochem_resources(info_file_path)
        consts = Constants(const_file)
        self.aev_computer = AEVComputer(**consts)
        self.aev_computer.to(device)

    def forward(self, species, positions):
        incoords = positions
        inspecies = species
        aev = self.aev_computer((inspecies.unsqueeze(0), incoords.unsqueeze(0)))
        sumaevs = torch.mean(aev.aevs)
        return sumaevs
## setup
N=100
species = torch.randint(0, 7, (N,), device="cuda")
pos = np.random.random((N, 3))
for optimize in [True, False]:
    print("JIT optimize = ", optimize)
    torch._C._jit_set_profiling_executor(optimize)
    torch._C._jit_set_profiling_mode(optimize)
    model = Model("cuda")
    model = torch.jit.script(model)
    grads = []
    for i in range(10):
        incoords = torch.tensor(pos, dtype=torch.float32, requires_grad=True, device="cuda")
        result = model(species, incoords)
        result.backward(retain_graph=True)
        grad = incoords.grad
        grads.append(grad.cpu().numpy())
        print(i, "max percentage error: ", np.max(100.0 * np.abs((grads[0] - grads[-1]) / grads[0])))
The output I get on an RTX 3090 is:
JIT optimize = True
Downloading ANI model parameters ...
0 max percentage error: 0.0
1 max percentage error: 0.00055674225
2 max percentage error: 217.80972
3 max percentage error: 217.80959
4 max percentage error: 217.81003
5 max percentage error: 217.80975
6 max percentage error: 217.80972
7 max percentage error: 217.81082
8 max percentage error: 217.80956
9 max percentage error: 217.81024
JIT optimize = False
0 max percentage error: 0.0
1 max percentage error: 0.0003876826
2 max percentage error: 0.0002178617
3 max percentage error: 0.00021537923
4 max percentage error: 0.0005815239
5 max percentage error: 0.0010768962
6 max percentage error: 0.00017895782
7 max percentage error: 0.00035465648
8 max percentage error: 0.00039845158
9 max percentage error: 0.00018266498
I have found that a workaround which removes the errors is to replace a ** operation with torch.float_power: https://github.com/aiqm/torchani/commit/172b6fe85d3ab2acd3faa7a025b5aded22f2537c
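For illustration, here is a minimal sketch of that kind of substitution; the helper name pow_workaround and the cast back to the input dtype are my own additions (torch.float_power always evaluates in float64), not the exact change in the linked commit:

import torch

def pow_workaround(base: torch.Tensor, exponent: torch.Tensor) -> torch.Tensor:
    # Same result as base ** exponent, but routed through torch.float_power,
    # which computes in float64; cast back so downstream dtypes are unchanged.
    return torch.float_power(base, exponent).to(base.dtype)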
Thanks for reporting the issue!
This is a problem with NVFuser. A bug report has been filed at https://github.com/pytorch/pytorch/issues/84510.
The minimal reproducible example I extracted from the angular function is the following:
from torch import Tensor

def angular_terms(Rca: float, ShfZ: Tensor, EtaA: Tensor, Zeta: Tensor,
                  ShfA: Tensor, vectors12: Tensor) -> Tensor:
    vectors12 = vectors12.view(2, -1, 3, 1, 1, 1, 1)
    cos_angles = vectors12.prod(0).sum(1)
    ret = (cos_angles + ShfZ) * Zeta * ShfA * 2
    return ret.flatten(start_dim=1)
Replacing the ** operation with torch.float_power will not address the root cause of this problem.
At this moment, I would recommend disabling NVFuser by running the following:
torch._C._jit_set_nvfuser_enabled(False)
This switches to the NNC fuser (https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#fusers) instead of NVFuser, which I have tested and is working correctly.
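For reference, a minimal sketch of how this could be combined with the reproduction script above; placing the call once, before the scripted model is first executed, is my assumption, since the flag is process-global:

import torch

# Fall back to the NNC fuser instead of NVFuser for TorchScript fusion groups.
torch._C._jit_set_nvfuser_enabled(False)

model = torch.jit.script(Model("cuda"))  # Model as defined in the reproduction above
incoords = torch.tensor(pos, dtype=torch.float32, requires_grad=True, device="cuda")
result = model(species, incoords)
result.backward()
grad = incoords.grad  # gradients should now agree with the unoptimised run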