How to migrate from scipy's odeint to torchdiffeq, including passing arguments
Hi, I'm trying to migrate my ODE-solving code from scipy to torchdiffeq, hoping it will be faster. My ODE function looks like
def rhs_unit_vector(z, t, J, K, n, omega):
    # several details, where J, K, n, omega are parameters, and z is the collection of x, y, theta
    return np.concatenate((x_next, y_next, theta_next))
and I solve the ODE with scipy's odeint like
sols = odeint(f.rhs_unit_vector, z0, t, args=(J, K, n, omega))
It works fine. Now I'm trying to speed it up by solving it with torchdiffeq. I've searched the docs and issues but found nothing relevant to my problem. I have two questions:
Q1: How do I pass the args J, K, n, omega? It keeps reporting that odeint from torchdiffeq takes 3 positional arguments but 7 were given.
Q2: Why can't I pass z, the first argument, into the ODE function? In the code below, I pass z0 as a collection of x, y, theta. But inside the function, z is always tensor(0., dtype=torch.float64), while the t printed inside the function appears to be the z0 I passed in from outside. Where did I go wrong?
import torch
from torchdiffeq import odeint
import numpy as np
def packed_rhs_unit_vector(z, t):
    #Instantiate -- set up
    print('----------inside func-----------------')
    print("z:", z)
    print("t:", t)
a, dt, T, n, L = 1, 0.5, 2000, 1000, 1
t = torch.tensor([dt*i for i in range(int(T/dt))])
x0 = np.random.uniform(-L,L,n)
y0 = np.random.uniform(-L,L,n)
theta0 = np.random.uniform(-np.pi,np.pi,n)
#Do simulation
#tic = time.clock()
z0 = np.array([x0,y0,theta0])
z0 = z0.flatten()
print('---------t---------------')
print(t)
z0 = torch.tensor(z0)
print('----------z0--------------')
print(z0)
sols = odeint(packed_rhs_unit_vector, z0, t)
The whole output is
---------t---------------
tensor([0.0000e+00, 5.0000e-01, 1.0000e+00, ..., 1.9985e+03, 1.9990e+03,
1.9995e+03])
----------z0--------------
tensor([ 0.0242, -0.5233, -0.5568, ..., 0.5798, -0.1276, -2.8584],
dtype=torch.float64)
insideeeeeeeeeeeee tensor([ 0.0242, -0.5233, -0.5568, ..., 0.5798, -0.1276, -2.8584],
dtype=torch.float64) tensor([0.0000e+00, 5.0000e-01, 1.0000e+00, ..., 1.9985e+03, 1.9990e+03,
1.9995e+03])
----------inside func-----------------
z: tensor(0., dtype=torch.float64)
t: tensor([ 0.0242, -0.5233, -0.5568, ..., 0.5798, -0.1276, -2.8584],
dtype=torch.float64)
----------inside func-----------------
z: tensor(0., dtype=torch.float64)
t: tensor([ 0.0242, -0.5233, -0.5568, ..., 0.5798, -0.1276, -2.8584],
dtype=torch.float64)
Here the insideeeeeee line comes from a print I added inside torchdiffeq/_impl/odeint.py to show its y0 and t, which look correct at that point.
So where did it go wrong? Thanks for your time and patience!