Can odeint_adjoint solve parametric ODEs?
Let's say I have some linear system of parametric ODEs:
def f(t, y, params):
    """Evaluate the right-hand side of a two-variable linear parametric ODE.

    Args:
        t: Current time. It is not used because the system is autonomous.
        y: State tensor. ``y[0]`` and ``y[1]`` are the two state variables.
        params: Anything that unpacks into three coefficients ``(a, b, c)``.

    Returns:
        A new tensor shaped like ``y`` holding ``dy/dt``. Entries of ``y``
        beyond the first two receive a zero derivative.
    """
    coef_a, coef_b, coef_c = params
    rate = torch.zeros_like(y)
    x0, x1 = y[0], y[1]
    rate[0] = coef_a * x0 - coef_b * x1
    rate[1] = coef_b * x0 - coef_c * x1
    return rate
How do I pass the parameters to odeint/odeint_adjoint? In scipy.integrate.odeint, this would look like this:
# scipy.integrate.odeint calls func(y, t, *args) unless tfirst=True, and it
# unpacks `args` into separate arguments. f above has the signature
# f(t, y, params), so request time-first calling and pass the three
# coefficients as a single tuple argument.
odeint(f, y0, t, args=((a, b, c),), tfirst=True)
Can someone please answer this question?
Hi! I'm not one of the developers, but I think you can do it this way:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchdiffeq import odeint, odeint_adjoint
class ODEfunc(nn.Module):
    """Linear two-variable ODE right-hand side with coefficients ``(a, b, c)``.

    By default, ``odeint_adjoint`` back-propagates only into the module's
    registered parameters (``func.parameters()``). A tensor stored as a plain
    attribute therefore receives no gradient, and ``.to(device)`` does not
    move it. This class registers tensor coefficients instead:

    * as an ``nn.Parameter`` when ``learnable=True``;
    * otherwise as a buffer.
    """

    def __init__(self, params, learnable=False):
        """Store the coefficients.

        Args:
            params: Three coefficients ``(a, b, c)``. This may be a tensor, an
                ``nn.Parameter``, or any iterable that unpacks into three values.
            learnable: If True, register the coefficients as an ``nn.Parameter``
                so that gradients, including adjoint gradients, flow into them.
        """
        super().__init__()
        if isinstance(params, nn.Parameter):
            # Already trainable. nn.Module.__setattr__ registers it as-is.
            self.params = params
        elif learnable:
            tensor = (
                params
                if isinstance(params, torch.Tensor)
                else torch.tensor(params, dtype=torch.get_default_dtype())
            )
            self.params = nn.Parameter(tensor)
        elif isinstance(params, torch.Tensor):
            # A buffer follows .to()/.cuda() and is saved in state_dict,
            # but it is not trained.
            self.register_buffer("params", params)
        else:
            self.params = params

    def forward(self, t, y):
        """Return ``dy/dt`` for state ``y``. ``t`` is unused (autonomous system)."""
        a, b, c = self.params
        dydt = torch.zeros_like(y)
        dydt[0] = a * y[0] - b * y[1]
        dydt[1] = b * y[0] - c * y[1]
        return dydt
# Integrate the same system with the direct and the adjoint solver and
# overlay the two: lines show odeint, markers show odeint_adjoint.
time = torch.linspace(0.0, 10.0, 100)
params = torch.tensor([1.0, 2.0, 3.0])  # coefficients (a, b, c)
y0 = torch.tensor([1.5, 0.25])  # initial state (y[0], y[1])
func = ODEfunc(params)
# Detach before plotting: matplotlib converts through numpy, which refuses
# tensors that require grad.
result = odeint(func, y0, time).detach()
result_adjoint = odeint_adjoint(func, y0, time).detach()
plt.plot(time, result[:, 0], color='tab:blue', zorder=0, label="y[0] (odeint)")
plt.scatter(time, result_adjoint[:, 0], color='tab:blue', label="y[0] (odeint_adjoint)")
plt.plot(time, result[:, 1], color='tab:red', zorder=0, label="y[1] (odeint)")
plt.scatter(time, result_adjoint[:, 1], color='tab:red', label="y[1] (odeint_adjoint)")
plt.xlabel('Time')
plt.ylabel('Y')
plt.legend()  # the labels above are only shown once a legend is drawn
plt.show()
Output (the plot image was not preserved):
Does this approach work when the parameters are not constant but differ between the samples we want to train on?
I had to do something similar. I followed the approach of the [ANODE](https://arxiv.org/abs/1904.01681) paper and appended the additional parameters to the state vector. The difference is that I set the time derivative of those parameters to 0, so they stay constant during integration.
$$ \frac{d}{dt} \begin{bmatrix} y(t) \\ p \end{bmatrix} = \begin{bmatrix} f( \begin{bmatrix} y(t) \\ p \end{bmatrix} , t) \\ 0 \end{bmatrix} $$
I modified the code sample from @ftavella; it produces the same output.
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchdiffeq import odeint, odeint_adjoint
class ODEfunc(nn.Module):
    """Linear ODE whose coefficients travel inside the state vector.

    The state is laid out as ``[y0, y1, ..., a, b, c]``: the last three
    entries are the coefficients. Their time derivative is left at zero, so
    they stay constant while the first two entries evolve. Because of this,
    each trajectory can carry its own coefficient values.
    """

    def __init__(self):
        super().__init__()

    def forward(self, t, y):
        """Return the time derivative of the augmented state ``y``.

        ``t`` is unused because the system is autonomous.
        """
        coef_a, coef_b, coef_c = (y[k] for k in (-3, -2, -1))
        deriv = torch.zeros_like(y)
        x0, x1 = y[0], y[1]
        deriv[0] = coef_a * x0 - coef_b * x1
        deriv[1] = coef_b * x0 - coef_c * x1
        return deriv
# Append the coefficients to the initial state so the solver carries them
# along. ODEfunc.forward reads them from the last three entries:
# [y0, y1, a, b, c].
time = torch.linspace(0.0, 10.0, 100)
params = torch.tensor([1.0, 2.0, 3.0])  # coefficients (a, b, c)
y0 = torch.tensor([1.5, 0.25])  # initial state (y[0], y[1])
func = ODEfunc()
initial = torch.cat((y0, params), 0)
# Detach before plotting: matplotlib converts through numpy, which refuses
# tensors that require grad.
result = odeint(func, initial, time).detach()
result_adjoint = odeint_adjoint(func, initial, time).detach()
# Only the first two columns are the evolving state. The remaining columns
# hold the constant coefficients.
plt.plot(time, result[:, 0], color='tab:blue', zorder=0, label="y[0] (odeint)")
plt.scatter(time, result_adjoint[:, 0], color='tab:blue', label="y[0] (odeint_adjoint)")
plt.plot(time, result[:, 1], color='tab:red', zorder=0, label="y[1] (odeint)")
plt.scatter(time, result_adjoint[:, 1], color='tab:red', label="y[1] (odeint_adjoint)")
plt.xlabel('Time')
plt.ylabel('Y')
plt.legend()  # the labels above are only shown once a legend is drawn
plt.show()