torchdiffeq

Support higher order autodiff?

Open • woct0rdho opened this issue Feb 06 '19 • 6 comments

Thanks a lot for your work! However, it seems that backward after grad is not supported yet. Here is a minimal example:

# dy/dt = a b y, t = 0...1
# y1 = y0 exp(a b)
# dy1/da = b y0 exp(a b)
# dy1/dy0 = exp(a b)

import torch
from torch import nn
from torch.autograd import grad
from torchdiffeq import odeint_adjoint as odeint


class Func(nn.Module):
    def __init__(self):
        super(Func, self).__init__()
        self.a = nn.Parameter(torch.tensor(2.0))
        self.b = nn.Parameter(torch.tensor(3.0))

    def forward(self, t, y):
        return self.a * self.b * y


if __name__ == '__main__':
    func = Func()
    y0 = torch.tensor(4.0, requires_grad=True)
    t = torch.tensor([0.0, 1.0])
    y1 = odeint(func, y0, t)[1]
    print(y1)
    y1.backward(retain_graph=True)

    dy1_da = grad(y1, func.a, create_graph=True)[0]
    print(dy1_da)
    dy1_da.backward(retain_graph=True)

    dy1_dy0 = grad(y1, y0, create_graph=True)[0]
    print(dy1_dy0)
    dy1_dy0.backward(retain_graph=True)

Both dy1_da and dy1_dy0 have no grad_fn, so dy1_da.backward and dy1_dy0.backward throw errors. It would be nice if you could support these operations, so that more complex applications could be built on top of your package.
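
For reference, here is a quick numeric check of the closed-form values from the comments at the top of the script (a = 2, b = 3, y0 = 4), so it is clear what the correct output and gradients would be once higher-order autodiff works. This is plain Python, independent of torchdiffeq, and only for illustration:

import math

a, b, y0 = 2.0, 3.0, 4.0
print(y0 * math.exp(a * b))      # y1      = y0 * exp(a*b)     ~ 1613.7
print(b * y0 * math.exp(a * b))  # dy1/da  = b * y0 * exp(a*b) ~ 4841.1
print(math.exp(a * b))           # dy1/dy0 = exp(a*b)          ~ 403.4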

woct0rdho, Feb 06 '19

Yes, this is on the TODO list for now. https://github.com/rtqichen/torchdiffeq/blob/master/torchdiffeq/_impl/adjoint.py#L31 I'd need to get a bit finicky with PyTorch's Function.

For now, I think the non-adjoint version should support higher-order autodiff.
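
For concreteness, here is a minimal, untested sketch of that workaround on the example above: it swaps odeint_adjoint for the plain odeint (which backpropagates directly through the solver's operations) and then takes a second-order gradient. The name d2y1_dadb is only illustrative:

import torch
from torch import nn
from torch.autograd import grad
from torchdiffeq import odeint  # plain (non-adjoint) solver


class Func(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.tensor(2.0))
        self.b = nn.Parameter(torch.tensor(3.0))

    def forward(self, t, y):
        return self.a * self.b * y


func = Func()
y0 = torch.tensor(4.0, requires_grad=True)
t = torch.tensor([0.0, 1.0])
y1 = odeint(func, y0, t)[1]

# create_graph=True records the operations of the first-order gradient...
dy1_da = grad(y1, func.a, create_graph=True)[0]
# ...so a second-order gradient can then be taken from it.
d2y1_dadb = grad(dy1_da, func.b)[0]
print(dy1_da, d2y1_dadb)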

rtqichen, Feb 12 '19

Hi @rtqichen, thank you for your work! So how can we currently take higher-order gradients of the integration outputs with respect to the inputs (e.g. the initial conditions or other inputs)?

Sceki, Jun 02 '22

Hello @rtqichen. Has there been any progress on higher-order autodiff with the adjoint method?

EyalRozenberg1, Jul 03 '23

No sorry, zero progress has been made since 2019. If anyone wants to submit a PR for this, I can approve it.

rtqichen, Jul 03 '23

Thanks for your comment, Ricky. Are there any action items that should be taken? Eyal

EyalRozenberg1, Jul 03 '23

@rtqichen Thanks for your work. However, GPU memory keeps leaking when I use import torch.autograd as ag. My code is as follows:

for itr in range(1, 5):
    print('iteration: ', itr)
    print("start_memory_allcoated(MB) {}".format(torch.cuda.memory_allocated() / 1048576))
    optimizer.zero_grad()
    batch_y0, batch_t, batch_y = get_batch()
    batch_y0.requires_grad = True
    print('batch_y0.shape: ', batch_y0.shape)
    pred_y = odeint(func, batch_y0, batch_t).to(device)
    # ... rest of the iteration (loss, gradients via ag, optimizer step) omitted ...
    print("end_memory_allcoated(MB) {}".format(torch.cuda.memory_allocated() / 1048576))

My results are as follows:

iteration:  1
start_memory_allcoated(MB) 0.0146484375
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0244140625
iteration:  2
start_memory_allcoated(MB) 0.0244140625
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.025390625
iteration:  3
start_memory_allcoated(MB) 0.025390625
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0263671875
iteration:  4
start_memory_allcoated(MB) 0.0263671875
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.02734375
iteration:  5
start_memory_allcoated(MB) 0.02734375
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0283203125
iteration:  6
start_memory_allcoated(MB) 0.0283203125
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.029296875
iteration:  7
start_memory_allcoated(MB) 0.029296875
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0302734375
iteration:  8
start_memory_allcoated(MB) 0.0302734375
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.03125
iteration:  9
start_memory_allcoated(MB) 0.03125
batch_y0.shape:  torch.Size([20, 1, 2])
end_memory_allcoated(MB) 0.0322265625

And it eventually causes an out-of-memory error! Hope to get your reply as soon as possible.
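
One diagnostic that can help narrow this down (a sketch of my own, not part of torchdiffeq; the helper name is made up) is counting the CUDA tensors that are still reachable at the start of each iteration, since steady growth there usually means something is still holding a reference to an earlier iteration's autograd graph:

import gc
import torch


def report_cuda_tensors(tag):
    # Count CUDA tensors the garbage collector can still reach.
    # If this number grows every iteration, some object (e.g. a stored
    # loss or gradient) is probably keeping an old autograd graph alive.
    count, nbytes = 0, 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) and obj.is_cuda:
                count += 1
                nbytes += obj.element_size() * obj.nelement()
        except Exception:
            pass
    print("{}: {} CUDA tensors, {:.3f} MB".format(tag, count, nbytes / 1048576))

Calling report_cuda_tensors at the top of the loop should make it clearer whether the growth comes from tensors surviving across iterations or from allocations inside the solver itself.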

wangmiaowei, Feb 25 '24