torchdiffeq

Can odeint_adjoint solve parametric ODEs?

Open • mayar-shahin opened this issue 2 years ago • 4 comments

Let's say I have a linear system of parametric ODEs:

```python
import torch


def f(t, y, params):
    a, b, c = params
    dydt = torch.zeros_like(y)
    dydt[0] = a*y[0] - b*y[1]
    dydt[1] = b*y[0] - c*y[1]
    return dydt
```

How do I pass the parameters to odeint/odeint_adjoint? In scipy.integrate.odeint, I would do it like this:

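```python
# f takes (t, y, params): pass params as one extra argument and set tfirst=True
odeint(f, y0, t, args=((a, b, c),), tfirst=True)
```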

mayar-shahin • Apr 07 '23

Can someone please answer this question?

iranroman • Jun 11 '23

Hi! I'm not one of the developers, but I think you can do it this way, by storing the parameters on an nn.Module and reading them inside forward:

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchdiffeq import odeint, odeint_adjoint


class ODEfunc(nn.Module):
    def __init__(self, params):
        super().__init__()
        # nn.Parameter registers (a, b, c) in func.parameters(), which is
        # where odeint_adjoint looks for tensors to compute gradients for.
        self.params = nn.Parameter(params)

    def forward(self, t, y):
        a, b, c = self.params
        dydt = torch.zeros_like(y)
        dydt[0] = a*y[0] - b*y[1]
        dydt[1] = b*y[0] - c*y[1]
        return dydt

time = torch.linspace(0.0, 10.0, 100)
params = torch.tensor([1.0, 2.0, 3.0])
y0 = torch.tensor([1.5, 0.25])

func = ODEfunc(params)
result = odeint(func, y0, time)
result_adjoint = odeint_adjoint(func, y0, time)

# The solutions depend on trainable parameters, so detach them before plotting.
result, result_adjoint = result.detach(), result_adjoint.detach()

plt.plot(time, result[:, 0], color='tab:blue', zorder=0, label="odeint")
plt.scatter(time, result_adjoint[:, 0], color='tab:blue', label="odeint_adjoint")

plt.plot(time, result[:, 1], color='tab:red', zorder=0)
plt.scatter(time, result_adjoint[:, 1], color='tab:red')

plt.xlabel('Time')
plt.ylabel('Y')
plt.legend()
plt.show()
```

Output: [plot of y[0] (blue) and y[1] (red) versus time; odeint as lines, odeint_adjoint as markers]

ftavella • Aug 03 '23

Does this approach work when the parameters are not constant but differ between the samples we want to train on?

Mr-Markovian • Oct 15 '24

I had to do something similar. I followed the format from the [ANODE](https://arxiv.org/abs/1904.01681) paper, appending the additional parameters to the state vector. The difference is that I make sure the derivative of those entries is 0, so they stay constant during the integration.

$$ \frac{d}{dt} \begin{bmatrix} y(t) \\ p \end{bmatrix} = \begin{bmatrix} f( \begin{bmatrix} y(t) \\ p \end{bmatrix} , t) \\ 0 \end{bmatrix} $$
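Written out for the two-component system from the question, with $p = (a, b, c)$, the augmented state has five entries, and the last three keep their initial values for the whole integration:

$$ \frac{d}{dt} \begin{bmatrix} y_0(t) \\ y_1(t) \\ a \\ b \\ c \end{bmatrix} = \begin{bmatrix} a\, y_0(t) - b\, y_1(t) \\ b\, y_0(t) - c\, y_1(t) \\ 0 \\ 0 \\ 0 \end{bmatrix} $$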

I modified the code sample from @ftavella, and it produces the same output.

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchdiffeq import odeint, odeint_adjoint


class ODEfunc(nn.Module):
    def forward(self, t, y):
        # The last three entries of the state are the parameters (a, b, c).
        a, b, c = y[-3], y[-2], y[-1]
        # zeros_like leaves the parameter derivatives at 0, so they stay constant.
        dydt = torch.zeros_like(y)
        dydt[0] = a*y[0] - b*y[1]
        dydt[1] = b*y[0] - c*y[1]
        return dydt

time = torch.linspace(0.0, 10.0, 100)
params = torch.tensor([1.0, 2.0, 3.0])
y0 = torch.tensor([1.5, 0.25])

func = ODEfunc()
# Augmented initial state: [y0[0], y0[1], a, b, c]
initial = torch.cat((y0, params), 0)

result = odeint(func, initial, time)
result_adjoint = odeint_adjoint(func, initial, time)

plt.plot(time, result[:, 0], color='tab:blue', zorder=0, label="odeint")
plt.scatter(time, result_adjoint[:, 0], color='tab:blue', label="odeint_adjoint")

plt.plot(time, result[:, 1], color='tab:red', zorder=0)
plt.scatter(time, result_adjoint[:, 1], color='tab:red')

plt.xlabel('Time')
plt.ylabel('Y')
plt.legend()
plt.show()
```

psv4 • Jul 02 '25