torchmd-net
[Feature request] Training with torch.compile
Currently it is not possible to run a double backward pass (differentiating through a gradient) on a model wrapped with torch.compile. For instance, this code fails:
import torch
from torch import nn, Tensor


class Model(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        output = input * input
        return output


model = Model()
model = torch.compile(model, backend="inductor")
input = torch.randn(10, requires_grad=True)
y = model(input)
# First backward: dy/dinput, keeping the graph so it can be differentiated again
dy = torch.autograd.grad(
    y, input, grad_outputs=torch.ones_like(y), create_graph=True, retain_graph=True
)[0]
# Second backward: this call raises the RuntimeError
ddy = torch.autograd.grad(dy, input, grad_outputs=torch.ones_like(dy))[0]
It fails with this error:
$ python test_model.py
Traceback (most recent call last):
File "/home/raul/work/bcn/torchmd-net/tests/test_model.py", line 266, in <module>
ddy = torch.autograd.grad(dy, input, grad_outputs=torch.ones_like(dy))[0]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/autograd/__init__.py", line 412, in grad
result = _engine_run_backward(
^^^^^^^^^^^^^^^^^^^^^
File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/autograd/function.py", line 301, in apply
return user_fn(self, *args)
^^^^^^^^^^^^^^^^^^^^
File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 877, in backward
raise RuntimeError(
RuntimeError: torch.compile with aot_autograd does not currently support double backward
This happens even with the latest PyTorch nightly:
pytorch 2.3.0.dev20240313 py3.11_cpu_0 pytorch-nightly
This is a well-known limitation of torch.compile, see https://github.com/pytorch/pytorch/issues/91469
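For comparison, here is a minimal sketch of the same computation in eager mode, i.e. without torch.compile, where the double backward goes through. Since y = x * x, the first gradient is 2 * x and the second is 2 for every element. The variable names are only illustrative.

```python
import torch

x = torch.randn(10, requires_grad=True)
y = x * x  # same computation as Model.forward, but not compiled

# First backward: dy/dx = 2 * x. create_graph=True records the backward
# pass itself so that it can be differentiated again.
dy = torch.autograd.grad(
    y, x, grad_outputs=torch.ones_like(y), create_graph=True
)[0]

# Second backward: d(sum(2 * x))/dx = 2 for every element.
ddy = torch.autograd.grad(dy, x, grad_outputs=torch.ones_like(dy))[0]
```

So the failure comes from the compiled backward, not from the autograd calls themselves.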
TorchMD-Net uses backpropagation to compute forces from energies, which means that double backpropagation is required to train with forces.
Thus, training a compiled model with forces is currently not possible. For example, this test triggers the same error as above:
def test_compile_double_backwards():
    pl.seed_everything(12345)
    output_model = "Scalar"
    derivative = True
    args = load_example_args(
        "tensornet",
        remove_prior=True,
        output_model=output_model,
        derivative=derivative,
    )
    model = create_model(args)
    model = torch.compile(model, backend="inductor")
    z, pos, batch = create_example_batch(n_atoms=5)
    pos.requires_grad_(True)
    y, dy = model(z, pos, batch)
    dy.sum().backward()
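To make it concrete why training with forces needs a double backward, here is a rough eager-mode sketch. It is not TorchMD-Net code: `energy_model` is a hypothetical toy network and `target_forces` is placeholder data, both made up for illustration. The forces come from a first backward pass through the energy, and the loss on those forces then has to be backpropagated through that first pass to reach the parameters.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy energy model (not TorchMD-Net): maps each atom position
# to a scalar contribution; the total energy is the sum over atoms.
energy_model = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))

pos = torch.randn(5, 3, requires_grad=True)  # 5 atoms, 3D positions
target_forces = torch.zeros(5, 3)  # placeholder reference forces

energy = energy_model(pos).sum()

# First backward: forces = -dE/dpos. create_graph=True keeps the graph of
# this gradient computation so that the force loss can be differentiated.
forces = -torch.autograd.grad(energy, pos, create_graph=True)[0]

loss = ((forces - target_forces) ** 2).mean()

# Second backward: d(loss)/d(parameters) goes through the first backward.
loss.backward()
```

Once `energy_model` is wrapped with torch.compile, the `loss.backward()` call would be expected to hit the same double backward limitation reported above.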
I am opening this issue to keep track of the feature.