torchdiffeq

Unable to use vmap on a function containing the ODE solver

Open · adam-hartshorne opened this issue 2 years ago · 1 comment

Suppose you have a solve function like this:

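import torch
from torch.func import vmap  # functorch.vmap in older PyTorch releases
from torchdiffeq import odeint

def f(t, x):
    ...  # dynamics elided in the original report; must return a tensor shaped like x

def solve(y0):
    t_eval = torch.linspace(0.0, 1.0, 2)
    traj = odeint(f, y0, t_eval)  # shape (len(t_eval), *y0.shape)
    return traj[-1]  # state at the final time

def other_func(X):
    # apply solve independently to each slice of X along dim 0
    return vmap(solve)(X)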

where X has, for example, shape (n, m, d), so each call to solve receives one (m, d) slice.
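To make the intended semantics concrete, here is a minimal sketch, assuming PyTorch 2.x (where vmap lives in torch.func). It swaps torchdiffeq's odeint for a hypothetical fixed-step explicit Euler integrator, `euler_odeint`, written only with ordinary tensor ops (no autograd.Function), and uses hypothetical linear-decay dynamics. It checks that `vmap(solve)(X)` on X of shape (n, m, d) matches solving each (m, d) slice in a Python loop.

```python
import torch
from torch.func import vmap


def f(t, x):
    # Hypothetical placeholder dynamics: linear decay dx/dt = -x.
    return -x


def euler_odeint(func, y0, t, steps_per_interval=10):
    # Hypothetical stand-in for torchdiffeq.odeint: fixed-step explicit Euler
    # built only from ordinary tensor ops, so it composes with vmap.
    ys = [y0]
    y = y0
    for t0, t1 in zip(t[:-1], t[1:]):
        h = (t1 - t0) / steps_per_interval
        for k in range(steps_per_interval):
            y = y + h * func(t0 + k * h, y)
        ys.append(y)
    return torch.stack(ys)  # shape (len(t), *y0.shape)


def solve(y0):
    t_eval = torch.linspace(0.0, 1.0, 2)
    traj = euler_odeint(f, y0, t_eval)
    return traj[-1]  # state at the final time


torch.manual_seed(0)
n, m, d = 4, 3, 2
X = torch.randn(n, m, d)
mapped = vmap(solve)(X)                      # each call sees one (m, d) slice
looped = torch.stack([solve(x) for x in X])  # explicit per-slice loop
print(mapped.shape)  # torch.Size([4, 3, 2])
```

With step size 0.1, each Euler step multiplies the state by 0.9, so both results should be close to `X * 0.9**10`.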

This fails with:

RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this
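The error is triggered by torch.autograd.Function subclasses written in the legacy style, where forward takes ctx directly. The sketch below uses a hypothetical `OldStyleSquare` (not a torchdiffeq class) to reproduce the same kind of failure, assuming PyTorch 2.x. In recent releases the wording differs from the message above and instead asks for a setup_context staticmethod, but it is still a RuntimeError.

```python
import torch
from torch.func import vmap


class OldStyleSquare(torch.autograd.Function):
    # Hypothetical example in the legacy style: forward receives ctx.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


x = torch.randn(5, 3)
plain = OldStyleSquare.apply(x)  # works outside torch.func transforms

try:
    vmap(OldStyleSquare.apply)(x)  # legacy-style Function under vmap
    error = None
except RuntimeError as e:
    error = e
print(type(error).__name__)
```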

I believe a solution is now available in the PyTorch nightly builds, which add support for using autograd.Function with functorch transforms, as described here:

https://pytorch.org/docs/master/notes/extending.func.html

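Adopting it would require small changes to torchdiffeq's autograd.Function subclasses: forward must no longer take ctx, a separate setup_context staticmethod must save what backward needs, and a vmap rule must be supplied (either generate_vmap_rule = True or a custom vmap staticmethod).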
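Following the linked notes, here is a minimal sketch of an autograd.Function in the torch.func-compatible style, assuming PyTorch 2.x. `Square` is a hypothetical stand-in, not one of torchdiffeq's classes: forward takes only the inputs, setup_context saves tensors on ctx, and `generate_vmap_rule = True` asks PyTorch to derive the batching rule from the PyTorch ops used in the staticmethods.

```python
import torch
from torch.func import grad, vmap


class Square(torch.autograd.Function):
    # torch.func-compatible style: forward does not receive ctx.
    generate_vmap_rule = True  # let PyTorch derive the vmap rule

    @staticmethod
    def forward(x):
        return x * x

    @staticmethod
    def setup_context(ctx, inputs, output):
        # inputs is the tuple of arguments that were passed to forward.
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


x = torch.randn(5, 3)
batched = vmap(Square.apply)(x)  # no longer raises under vmap
# Per-row gradient of sum(row**2), composed with vmap.
per_row_grad = vmap(grad(lambda row: Square.apply(row).sum()))(x)
```

generate_vmap_rule only works when forward, setup_context and backward are themselves vmap-transformable; otherwise the notes describe writing a vmap staticmethod by hand. Whether the solver's own Functions qualify is an assumption that would need checking.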

adam-hartshorne, Jan 16 '23 12:01