torchdiffeq
Unable to use vmap on a function containing the ODE solver
Suppose you have a `solve` function like this:
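```python
import torch
from torch.func import vmap  # functorch.vmap on older PyTorch versions
from torchdiffeq import odeint

def f(t, x):
    ...  # dynamics (elided)

def solve(y0):
    t_eval = torch.linspace(0.0, 1.0, 2)
    traj = odeint(f, y0, t_eval)
    return traj[-1]  # state at t = 1.0

def other_func(X):
    mapped_traj = vmap(solve)(X)  # calls solve on each slice along dim 0
    return mapped_traj
```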
where `X` has, for example, shape `(n, m, d)`, so `vmap` calls `solve` on `n` initial states of shape `(m, d)`.
Calling `other_func(X)` fails with `RuntimeError: functorch functions (vmap, grad, vjp, etc.) currently do not support the use of autograd.Function. Please rewrite your function to not use autograd.Function while we work on fixing this`
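To check that the failure comes from `autograd.Function` itself rather than from the ODE solver, a minimal repro without torchdiffeq can help. This sketch assumes a recent PyTorch with `torch.func`; `OldStyleIdentity` and `solve_like` are hypothetical stand-ins, and the exact error text varies between PyTorch versions (newer releases ask you to override `setup_context` instead of printing the message quoted above).

```python
import torch
from torch.func import vmap


class OldStyleIdentity(torch.autograd.Function):
    # "Old-style" Function: forward takes ctx and there is no setup_context,
    # which torch.func transforms such as vmap reject.
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out


def solve_like(y0):
    # Hypothetical stand-in for a solver that uses an autograd.Function internally.
    return OldStyleIdentity.apply(y0) * 2.0


X = torch.randn(4, 3, 2)  # vmap maps over dim 0, so each call sees shape (3, 2)

try:
    vmap(solve_like)(X)
    raised = False
except RuntimeError as err:
    raised = True
    print(f"RuntimeError: {err}")
```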
I believe this is now supported in the PyTorch nightly builds, as described in the "Extending torch.func with autograd.Function" notes:
https://pytorch.org/docs/master/notes/extending.func.html
Using it would require a small change to torchdiffeq's core methods.
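A minimal sketch of what such a change might look like, assuming PyTorch 2.0 or later: per the linked notes, an `autograd.Function` can be used under `torch.func` transforms when `forward` no longer takes `ctx`, a separate `setup_context` staticmethod is defined, and either `generate_vmap_rule = True` is set or a `vmap` staticmethod is written by hand. The gradient pass-through below is hypothetical (`StitchGradient` and `straight_through_round` are illustrative names, not torchdiffeq's API); which Functions inside the solver actually need porting is an assumption to verify against the torchdiffeq source.

```python
import torch
from torch.func import grad, vmap


class StitchGradient(torch.autograd.Function):
    # Let torch.func derive the batching rule from forward/backward.
    generate_vmap_rule = True

    @staticmethod
    def forward(x1, out):
        # New style: no ctx here. The value comes from `out`;
        # x1 only matters for the backward pass.
        return out.clone()

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Nothing to save: backward passes the gradient straight through.
        pass

    @staticmethod
    def backward(ctx, grad_out):
        # Route the gradient to x1 unchanged; `out` receives no gradient.
        return grad_out, None


def straight_through_round(x):
    # Hypothetical example: forward value of round(x), gradient of identity.
    return StitchGradient.apply(x, torch.round(x))


xs = torch.tensor([0.2, 1.7, -2.4])
vals = vmap(straight_through_round)(xs)         # batched forward
grads = vmap(grad(straight_through_round))(xs)  # batched per-example gradients
```

Because `forward` and `backward` use only PyTorch operations, `generate_vmap_rule` should be able to batch both, so `vmap` and `vmap(grad(...))` work without a hand-written `vmap` staticmethod.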