torchdiffeq
Support higher order autodiff?
Thanks a lot for your work! However, it seems that `backward` after `grad` is not supported yet. Here is a minimal example:
```python
# dy/dt = a b y, t = 0...1
# y1 = y0 exp(a b)
# dy1/da = b y0 exp(a b)
# dy1/dy0 = exp(a b)
import torch
from torch import nn
from torch.autograd import grad
from torchdiffeq import odeint_adjoint as odeint


class Func(nn.Module):
    def __init__(self):
        super(Func, self).__init__()
        self.a = nn.Parameter(torch.tensor(2.0))
        self.b = nn.Parameter(torch.tensor(3.0))

    def forward(self, t, y):
        return self.a * self.b * y


if __name__ == '__main__':
    func = Func()
    y0 = torch.tensor(4.0, requires_grad=True)
    t = torch.tensor([0.0, 1.0])

    y1 = odeint(func, y0, t)[1]
    print(y1)
    y1.backward(retain_graph=True)

    dy1_da = grad(y1, func.a, create_graph=True)[0]
    print(dy1_da)
    dy1_da.backward(retain_graph=True)

    dy1_dy0 = grad(y1, y0, create_graph=True)[0]
    print(dy1_dy0)
    dy1_dy0.backward(retain_graph=True)
```
Both `dy1_da` and `dy1_dy0` do not have a `grad_fn`, so `dy1_da.backward` and `dy1_dy0.backward` throw errors. It would be nice if you could support these operations; then we could build more complex applications on top of your package.
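For reference, plugging a = 2, b = 3, y0 = 4 into the analytic formulas in the comments gives y1 = 4·e^6 ≈ 1613.7, dy1/da = 3·4·e^6 ≈ 4841.1, and dy1/dy0 = e^6 ≈ 403.4, so the first-order values printed by the script can be checked against these.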
Yes, this is on a TODO for now. https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py#L31 I'd need to get a bit finicky with pytorch's Function.
For now, I think the non-adjoint version should support higher-order autodiffs.
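As an illustration of that last point, here is a minimal sketch (reusing the same toy `Func` as in the example above; an illustration rather than a tested guarantee): with the plain `odeint` instead of `odeint_adjoint`, higher-order gradients can be obtained by passing `create_graph=True` and differentiating the result again.

```python
import torch
from torch import nn
from torch.autograd import grad
from torchdiffeq import odeint  # plain solver, not odeint_adjoint


class Func(nn.Module):
    # Same toy ODE as above: dy/dt = a * b * y
    def __init__(self):
        super(Func, self).__init__()
        self.a = nn.Parameter(torch.tensor(2.0))
        self.b = nn.Parameter(torch.tensor(3.0))

    def forward(self, t, y):
        return self.a * self.b * y


func = Func()
y0 = torch.tensor(4.0, requires_grad=True)
t = torch.tensor([0.0, 1.0])
y1 = odeint(func, y0, t)[1]

# First-order gradients, keeping the graph so they can be differentiated again.
dy1_da = grad(y1, func.a, create_graph=True)[0]   # analytically: b * y0 * exp(a*b)
dy1_dy0 = grad(y1, y0, create_graph=True)[0]      # analytically: exp(a*b)

# Second-order (mixed) term, which the adjoint version currently cannot produce.
d2y1_da_dy0 = grad(dy1_da, y0)[0]                 # analytically: b * exp(a*b)
print(dy1_da.item(), dy1_dy0.item(), d2y1_da_dy0.item())
```

This works because the non-adjoint solver backpropagates through its internal solver steps, building the full autograd graph, at the cost of extra memory.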
Hi @rtqichen, thank you for your work! So how can we currently take higher-order gradients of the integration outputs with respect to the inputs (e.g. the initial conditions)?
Hello @rtqichen. Has there been any progress on the higher-order autodiff feature with the adjoint method?
No sorry, zero progress has been made since 2019. If anyone wants to submit a PR for this, I can approve it.
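For anyone considering that PR, the general PyTorch requirement (sketched here on a toy function, not torchdiffeq's actual adjoint code) is that the custom `autograd.Function`'s backward must itself be built from differentiable tensor ops, so that `grad(..., create_graph=True)` can differentiate through it a second time:

```python
import torch


class Square(torch.autograd.Function):
    """Toy custom Function whose backward is itself differentiable."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Ordinary differentiable tensor ops only (no .detach(), no numpy),
        # so autograd can build a graph through this backward pass.
        return 2 * x * grad_output


x = torch.tensor(3.0, requires_grad=True)
y = Square.apply(x)
g, = torch.autograd.grad(y, x, create_graph=True)  # dy/dx = 2x -> 6.0
gg, = torch.autograd.grad(g, x)                    # d2y/dx2    -> 2.0
print(g.item(), gg.item())
```

Presumably the adjoint's backward pass (the augmented ODE solve) would need the same property.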
Thanks for your comment, Ricky. Are there any action items that should be taken? Eyal
@rtqichen Thanks for your work. But GPU memory keeps growing every iteration when I use `import torch.autograd as ag`. My code is as follows:
```python
for itr in range(1, 5):
    print('iteration: ', itr)
    print("start_memory_allocated(MB) {}".format(torch.cuda.memory_allocated() / 1048576))
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    batch_y0.requires_grad = True
    print('batch_y0.shape: ', batch_y0.shape)
    pred_y = odeint(func, batch_y0, batch_t).to(device)
    # ... (rest of the training step not shown)
    print("end_memory_allocated(MB) {}".format(torch.cuda.memory_allocated() / 1048576))
```
My results are as follows:
```
(base) [s2608314@node3c01(eddie) networks]$ python test.py
iteration: 1
start_memory_allocated(MB) 0.0146484375
batch_y0.shape: torch.Size([20, 1, 2])
end_memory_allocated(MB) 0.0244140625
iteration: 2
start_memory_allocated(MB) 0.0244140625
batch_y0.shape: torch.Size([20, 1, 2])
end_memory_allocated(MB) 0.025390625
iteration: 3
start_memory_allocated(MB) 0.025390625
batch_y0.shape: torch.Size([20, 1, 2])
end_memory_allocated(MB) 0.0263671875
iteration: 4
start_memory_allocated(MB) 0.0263671875
batch_y0.shape: torch.Size([20, 1, 2])
end_memory_allocated(MB) 0.02734375
iteration: 5
start_memory_allocated(MB) 0.02734375
batch_y0.shape: torch.Size([20, 1, 2])
end_memory_allocated(MB) 0.0283203125
iteration: 6
start_memory_allocated(MB) 0.0283203125
batch_y0.shape: torch.Size([20, 1, 2])
end_memory_allocated(MB) 0.029296875
iteration: 7
start_memory_allocated(MB) 0.029296875
batch_y0.shape: torch.Size([20, 1, 2])
end_memory_allocated(MB) 0.0302734375
iteration: 8
start_memory_allocated(MB) 0.0302734375
batch_y0.shape: torch.Size([20, 1, 2])
end_memory_allocated(MB) 0.03125
iteration: 9
start_memory_allocated(MB) 0.03125
batch_y0.shape: torch.Size([20, 1, 2])
end_memory_allocated(MB) 0.0322265625
```
And it eventually causes an out-of-memory error! Hope to get your reply as soon as possible.
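In case it helps: memory that grows a little on every iteration usually means some tensor still attached to the autograd graph is being kept alive across iterations (for example, storing `loss` or `pred_y` for logging without detaching, or holding the outputs of `ag.grad(..., create_graph=True)`). Below is a minimal self-contained sketch of a loop that does not accumulate; the model, optimizer, and data here are hypothetical stand-ins, not the original `get_batch`/`func`.

```python
import torch
from torch import nn
from torchdiffeq import odeint

device = 'cuda' if torch.cuda.is_available() else 'cpu'


class ODEFunc(nn.Module):
    # Hypothetical toy dynamics, just to make the loop runnable.
    def __init__(self):
        super(ODEFunc, self).__init__()
        self.net = nn.Linear(2, 2)

    def forward(self, t, y):
        return self.net(y)


func = ODEFunc().to(device)
optimizer = torch.optim.Adam(func.parameters())
t = torch.linspace(0.0, 1.0, 10).to(device)

history = []
for itr in range(1, 10):
    optimizer.zero_grad()
    batch_y0 = torch.randn(20, 1, 2, device=device, requires_grad=True)
    target = torch.randn(20, 1, 2, device=device)

    pred_y = odeint(func, batch_y0, t)
    loss = (pred_y[-1] - target).pow(2).mean()
    loss.backward()      # no retain_graph, so this iteration's graph is freed
    optimizer.step()

    # Keep only detached Python floats for logging; storing `loss` or
    # `pred_y` themselves would keep their graphs (and GPU memory) alive.
    history.append(loss.item())
    if torch.cuda.is_available():
        print(itr, torch.cuda.memory_allocated() / 1048576)
```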