[Feature request] Training with torch.compile

Open RaulPPelaez opened this issue 11 months ago • 0 comments

Currently, it is not possible to run a double backward pass (differentiating a gradient a second time) through a model compiled with `torch.compile`. For instance, this code fails:

```python
import torch
from torch import nn, Tensor


class Model(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        output = input * input
        return output


model = Model()
model = torch.compile(model, backend="inductor")
input = torch.randn(10, requires_grad=True)
y = model(input)
# First backward: create_graph=True keeps dy differentiable
dy = torch.autograd.grad(
    y, input, grad_outputs=torch.ones_like(y), create_graph=True, retain_graph=True
)[0]
# Second backward through the compiled graph raises the RuntimeError below
ddy = torch.autograd.grad(dy, input, grad_outputs=torch.ones_like(dy))[0]
```

It raises this error:

```
$ python test_model.py
Traceback (most recent call last):
  File "/home/raul/work/bcn/torchmd-net/tests/test_model.py", line 266, in <module>
    ddy = torch.autograd.grad(dy, input, grad_outputs=torch.ones_like(dy))[0]
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/autograd/__init__.py", line 412, in grad
    result = _engine_run_backward(
             ^^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/autograd/function.py", line 301, in apply
    return user_fn(self, *args)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/raul/miniforge3/envs/torchnightly/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 877, in backward
    raise RuntimeError(
RuntimeError: torch.compile with aot_autograd does not currently support double backward
```

This happens even with the latest PyTorch nightly (`pytorch 2.3.0.dev20240313 py3.11_cpu_0 pytorch-nightly`).

This is a well-known limitation of `torch.compile`; see https://github.com/pytorch/pytorch/issues/91469.
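For comparison, here is a minimal sketch of the same computation without `torch.compile`. In eager mode the double backward succeeds: since y = x², the first derivative is 2x and the second derivative is 2 for every element. The tensor is named `x` instead of `input` only to avoid shadowing the Python builtin.

```python
import torch
from torch import nn, Tensor


class Model(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return input * input


model = Model()  # eager mode: no torch.compile
x = torch.randn(10, requires_grad=True)
y = model(x)
# First derivative dy/dx = 2x; create_graph=True records it so it can be differentiated again
dy = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]
# Second derivative d(2x)/dx = 2 for every element
ddy = torch.autograd.grad(dy, x, grad_outputs=torch.ones_like(dy))[0]
print(ddy)
```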

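TorchMD-Net computes forces by backpropagating through the predicted energy (the forces are the negative gradient of the energy with respect to the atomic positions). Training against forces therefore requires a double backward pass: one to obtain the forces and a second one to propagate the force loss back to the model parameters.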
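To illustrate why force training needs that second pass, here is a minimal, self-contained sketch in eager PyTorch. `ToyEnergyModel` and the zero reference forces are hypothetical stand-ins, not TorchMD-Net's API: the model maps positions to a scalar energy, the forces come from a first `torch.autograd.grad` call with `create_graph=True`, and `loss.backward()` then differentiates through that gradient to reach the weights. That last call is the double backward that the compiled model cannot currently perform.

```python
import torch
from torch import nn

torch.manual_seed(0)


class ToyEnergyModel(nn.Module):
    """Hypothetical toy model: maps atomic positions to a single scalar energy."""

    def __init__(self) -> None:
        super().__init__()
        self.net = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        # Sum per-atom contributions into a total energy
        return self.net(pos).sum()


model = ToyEnergyModel()
pos = torch.randn(5, 3, requires_grad=True)  # 5 atoms, xyz coordinates
energy = model(pos)

# First backward: forces = -dE/dpos. create_graph=True keeps this
# gradient computation in the graph so it can be differentiated again.
forces = -torch.autograd.grad(energy, pos, create_graph=True)[0]

reference_forces = torch.zeros_like(forces)  # hypothetical training targets
loss = ((forces - reference_forces) ** 2).mean()

# Second backward: differentiates through the force computation to the weights
loss.backward()
```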

Consequently, the following test currently fails with the same error:

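```python
import torch
import lightning.pytorch as pl  # assumed; older installs: import pytorch_lightning as pl
from torchmdnet.models.model import create_model
from utils import load_example_args, create_example_batch  # assumed: helpers in torchmd-net's tests/utils.py


def test_compile_double_backwards():
    pl.seed_everything(12345)
    output_model = "Scalar"
    derivative = True  # the model also returns forces computed via autograd
    args = load_example_args(
        "tensornet",
        remove_prior=True,
        output_model=output_model,
        derivative=derivative,
    )
    model = create_model(args)
    model = torch.compile(model, backend="inductor")
    z, pos, batch = create_example_batch(n_atoms=5)
    pos.requires_grad_(True)
    y, dy = model(z, pos, batch)  # energies and forces
    # Backpropagating through the forces requires a double backward
    dy.sum().backward()
```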

I am opening this issue to keep track of this feature.

RaulPPelaez, Mar 18 '24 10:03