
How to migrate from scipy's odeint to torchdiffeq, including passing arguments

Open · gitGksgk opened this issue 3 years ago · 0 comments

Hi, I'm trying to migrate my ODE-solving code from SciPy to torchdiffeq, believing it is the right tool. My ODE function looks like this:

```python
import numpy as np

def rhs_unit_vector(z, t, J, K, n, omega):
    # several details, where J, K, n, omega are parameters, and z is the collection of x, y, theta
    return np.concatenate((x_next, y_next, theta_next))
```
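The comment says z packs x, y and theta together. Judging by how z0 is built further down (np.array([x0, y0, theta0]).flatten()), the layout is n values of x, then n of y, then n of theta. Here is a minimal sketch of unpacking and repacking that layout, assuming that ordering. split_state and pack_state are hypothetical helpers, not part of the original code:

```python
import numpy as np


def split_state(z, n):
    """Split a flat state [x..., y..., theta...] into three blocks of length n."""
    # Assumes z was built as np.array([x, y, theta]).flatten() (row-major order)
    return z[:n], z[n:2 * n], z[2 * n:3 * n]


def pack_state(x, y, theta):
    """Inverse of split_state: join the three blocks back into one flat vector."""
    return np.concatenate((x, y, theta))


# Round trip on a small state built the same way as z0 in the issue (n = 4 here)
n = 4
x0 = np.random.uniform(-1, 1, n)
y0 = np.random.uniform(-1, 1, n)
theta0 = np.random.uniform(-np.pi, np.pi, n)
z0 = np.array([x0, y0, theta0]).flatten()

x, y, theta = split_state(z0, n)
assert np.array_equal(x, x0) and np.array_equal(y, y0) and np.array_equal(theta, theta0)
assert np.array_equal(pack_state(x, y, theta), z0)
```

The same slicing works on a torch tensor. For torchdiffeq, the repacking step would use torch.cat instead of np.concatenate, so that the function returns a tensor.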

I solve the ODE with SciPy's odeint like this:

```python
from scipy.integrate import odeint
sols = odeint(f.rhs_unit_vector, z0, t, args=(J, K, n, omega))
```

It works fine. Now I'm trying to speed it up by solving it with torchdiffeq. I've searched the docs and issues but found nothing relevant to my problem. I have two questions:

Q1: How do I pass the arguments J, K, n, omega? torchdiffeq's odeint keeps reporting that it takes 3 positional arguments but 7 were given.
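As the error message indicates, torchdiffeq's odeint accepts only three positional arguments (func, y0 and t) and has no args= parameter like SciPy's. It also calls func(t, y) with time first, which is the reverse of SciPy's func(y, t, *args). A common workaround is a closure that captures the parameters and swaps the order. Here is a minimal sketch. make_rhs and toy_rhs are hypothetical names, with toy_rhs standing in for rhs_unit_vector:

```python
def make_rhs(scipy_style_rhs, *params):
    """Adapt f(z, t, *params) (SciPy odeint order) to g(t, z) (torchdiffeq order).

    The parameters are captured in the closure, so the solver only needs
    to pass (t, z), which is all torchdiffeq's odeint supplies.
    """
    def rhs(t, z):
        return scipy_style_rhs(z, t, *params)
    return rhs


# Hypothetical stand-in for rhs_unit_vector: dz/dt = -k * z + c
def toy_rhs(z, t, k, c):
    return -k * z + c


rhs = make_rhs(toy_rhs, 2.0, 1.0)
print(rhs(0.0, 3.0))  # time first, then state: -2.0 * 3.0 + 1.0 = -5.0
```

Applied here, the call would be odeint(make_rhs(rhs_unit_vector, J, K, n, omega), z0, t), assuming rhs_unit_vector is first rewritten with torch operations so that it returns a tensor. functools.partial(rhs_unit_vector, J=J, K=K, n=n, omega=omega) alone would bind the parameters but leave the (z, t) order wrong. Another pattern, used in torchdiffeq's own examples, is a torch.nn.Module subclass that stores the parameters as attributes and defines forward(self, t, y).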

Q2: Why can't I pass z, the first argument, into the ODE function? In the code below, I pass z0 as the concatenation of x, y and theta. But inside the function, z always prints as tensor(0., dtype=torch.float64), while the t printed inside the function looks like the z0 I passed in from outside. Where did I go wrong?

```python
import torch
from torchdiffeq import odeint
import numpy as np

def packed_rhs_unit_vector(z, t):
    # Instantiate -- set up
    print('----------inside func-----------------')
    print("z:", z)
    print("t:", t)

a, dt, T, n, L = 1, 0.5, 2000, 1000, 1
t = torch.tensor([dt*i for i in range(int(T/dt))])
x0 = np.random.uniform(-L, L, n)
y0 = np.random.uniform(-L, L, n)
theta0 = np.random.uniform(-np.pi, np.pi, n)

# Do simulation
# tic = time.clock()
z0 = np.array([x0, y0, theta0])
z0 = z0.flatten()
print('---------t---------------')
print(t)
z0 = torch.tensor(z0)
print('----------z0--------------')
print(z0)

sols = odeint(packed_rhs_unit_vector, z0, t)
```

The whole output is:

```
---------t---------------
tensor([0.0000e+00, 5.0000e-01, 1.0000e+00,  ..., 1.9985e+03, 1.9990e+03,
        1.9995e+03])
----------z0--------------
tensor([ 0.0242, -0.5233, -0.5568,  ...,  0.5798, -0.1276, -2.8584],
       dtype=torch.float64)
insideeeeeeeeeeeee tensor([ 0.0242, -0.5233, -0.5568,  ...,  0.5798, -0.1276, -2.8584],
       dtype=torch.float64) tensor([0.0000e+00, 5.0000e-01, 1.0000e+00,  ..., 1.9985e+03, 1.9990e+03,
        1.9995e+03])
----------inside func-----------------
z: tensor(0., dtype=torch.float64)
t: tensor([ 0.0242, -0.5233, -0.5568,  ...,  0.5798, -0.1276, -2.8584],
       dtype=torch.float64)
----------inside func-----------------
z: tensor(0., dtype=torch.float64)
t: tensor([ 0.0242, -0.5233, -0.5568,  ...,  0.5798, -0.1276, -2.8584],
       dtype=torch.float64)
```

The insideeeeeee row comes from a print I added inside torchdiffeq/_impl/odeint.py to show its y0 and t, and both look correct there.

So where did it go wrong? Thanks for your time and patience!
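The printout fits torchdiffeq's calling order. torchdiffeq calls func(t, y), so packed_rhs_unit_vector(z, t) receives the first time point, t[0] = 0, as z, and receives the state as t. The sketch below reproduces the swap with a hypothetical fixed-step Euler integrator that uses the same (t, y) order. It is an illustration only, not torchdiffeq's solver, which defaults to the adaptive dopri5 method:

```python
def euler_odeint(func, y0, ts):
    """Fixed-step explicit Euler that calls func(t, y), time first, like torchdiffeq.

    Hypothetical stand-in for illustration; returns the state at every time in ts.
    """
    ys = [y0]
    y = y0
    for t0, t1 in zip(ts[:-1], ts[1:]):
        y = y + (t1 - t0) * func(t0, y)
        ys.append(y)
    return ys


seen = []


def packed_rhs_wrong(z, t):
    # SciPy-style signature: the solver's time lands in z, the state in t
    seen.append((z, t))
    return 0.0


def packed_rhs_right(t, z):
    # torchdiffeq-style signature: dz/dt = -z
    return -z


euler_odeint(packed_rhs_wrong, 5.0, [0.0, 0.5])
print(seen[0])  # (0.0, 5.0): "z" got the time 0.0, "t" got the state 5.0

ys = euler_odeint(packed_rhs_right, 1.0, [0.0, 0.5, 1.0])
print(ys)  # [1.0, 0.5, 0.25]
```

For the code above, this suggests declaring the function as def packed_rhs_unit_vector(t, z): and returning a tensor with the same shape as z. As posted, the function only prints and returns None, which no solver can integrate.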

gitGksgk, May 18 '22 07:05