Custom autograd fails with torchdeq in eval mode
It's a very niche problem, but it tripped me up big time :')
Issue
With model.eval(), z_pred will not have tracked gradients (z_pred[-1].requires_grad == False). Any custom torch.autograd.grad call on it then fails with: RuntimeError: One of the differentiated Tensors does not require grad.
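The symptom in isolation (a condensed sketch, assuming the default get_deq() solver; model.eval() puts the deq submodule into eval mode, which is what detaches the output):

import torch
from torchdeq import get_deq

layer = torch.nn.Linear(10, 10)
deq = get_deq()
z0 = torch.zeros(4, 10)

deq.train()
z_pred, info = deq(lambda z: layer(z), z0)
print(z_pred[-1].requires_grad)  # True

deq.eval()  # what model.eval() does to a deq submodule
z_pred, info = deq(lambda z: layer(z), z0)
print(z_pred[-1].requires_grad)  # False -> autograd.grad fails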
Minimal example
import torch
from torchdeq import get_deq
from torchdeq.norm import apply_norm, reset_norm


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)
        # project positions into the hidden space, so that the energy
        # actually depends on pos and forces can be computed
        self.pos_embed = torch.nn.Linear(3, 10)
        # deq
        self.deq = get_deq()
        apply_norm(self.layer, 'weight_norm')

    def implicit_layer(self, z, pos):
        return self.layer(z + self.pos_embed(pos))

    def forward(self, x, pos):
        z = torch.zeros_like(x)
        reset_norm(self.layer)
        f = lambda z: self.implicit_layer(z, pos)
        z_pred, info = self.deq(f, z)
        # if model.eval() -> z_pred[-1].requires_grad is False!
        # reduce to one energy per sample, matching the (10, 1) target
        energy = z_pred[-1].sum(dim=-1, keepdim=True)
        forces = -1 * (
            torch.autograd.grad(
                energy,
                # differentiate with respect to pos;
                # 'One of the differentiated Tensors appears to not have
                # been used in the graph' means pos did not enter the
                # computation of energy
                pos,
                grad_outputs=torch.ones_like(energy),
                create_graph=True,
                # allow_unused=True,
            )[0]
        )
        return energy, forces
def run(model, eval=False):
    if eval:
        model.eval()
    else:
        model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    for step in range(10):
        x = torch.randn(10, 10)
        # pos must require grad so that forces = -dE/dpos can be computed
        pos = torch.randn(10, 3, requires_grad=True)
        energy, forces = model(x, pos)
        # loss
        optimizer.zero_grad()
        energy_target = torch.randn(10, 1)
        energy_loss = torch.nn.functional.mse_loss(energy, energy_target)
        force_target = torch.randn(10, 3)
        force_loss = torch.nn.functional.mse_loss(forces, force_target)
        loss = energy_loss + force_loss
        if not eval:
            loss.backward()
            optimizer.step()
    return True


if __name__ == '__main__':
    model = MyModel()
    success = run(model, eval=False)
    print(f'train success: {success}')
    success = run(model, eval=True)
    print(f'eval success: {success}')
With model.train(), this works perfectly well. With model.eval(), we get the error: RuntimeError: One of the differentiated Tensors does not require grad.
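Until there is library support, a user-side workaround is to re-attach the detached fixed point to the graph with one extra function evaluation, i.e. a hand-rolled 1-step phantom gradient (the resulting forces only approximate the exact implicit gradient). A minimal sketch inside forward:

z_pred, info = self.deq(f, z)
z_star = z_pred[-1]
if not z_star.requires_grad:
    # eval mode: the deq output is detached; one more application of f
    # re-attaches it, so grad with respect to pos works again
    z_star = f(z_star.detach())
energy = z_star.sum(dim=-1, keepdim=True)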
Desired behaviour
A flag such that z_pred[-1].requires_grad is always True, even under model.eval():
self.deq = get_deq(grad_in_eval=True)
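For reference, the proposed behaviour can be emulated today by packaging the same one-step re-attachment behind a thin wrapper (GradDEQ and grad_in_eval are hypothetical names, not torchdeq API):

import torch
from torchdeq import get_deq


class GradDEQ(torch.nn.Module):
    # hypothetical wrapper: behaves like the proposed get_deq(grad_in_eval=True)
    def __init__(self, **deq_kwargs):
        super().__init__()
        self.deq = get_deq(**deq_kwargs)

    def forward(self, f, z0):
        z_pred, info = self.deq(f, z0)
        if torch.is_grad_enabled() and not z_pred[-1].requires_grad:
            # eval mode: re-attach the fixed point with one extra
            # function evaluation (approximate implicit gradient)
            z_pred = list(z_pred)
            z_pred[-1] = f(z_pred[-1].detach())
        return z_pred, info

With self.deq = GradDEQ() in MyModel.__init__, the eval run above goes through.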