torchdiffeq
torchdiffeq is slower than scipy's solve_ivp
Here's a sample script that solves the 1D heat equation, discretized with the method of lines. torchdiffeq turned out to be about 7x slower, which is not what I expected: I was expecting similar performance at worst, and probably better, since torchdiffeq was run on the GPU. Am I missing something? Thank you so much!
```python
import time

import numpy as np
import torch
from torchdiffeq import odeint
import matplotlib.pyplot as plt
from scipy.integrate import solve_ivp

print(f"GPU available: {torch.cuda.is_available()}")


def func(t, y, dy):
    # Method-of-lines RHS for u_t = u_xx: central differences on the interior;
    # the boundary entries of dy stay 0, so the Dirichlet values stay fixed.
    # The same dy buffer is written into and returned on every call.
    N = len(y)
    h = 1/(N-1)
    idx = np.arange(1, N-1)
    dy[idx] = (y[idx-1] - 2*y[idx] + y[idx+1]) / h**2
    return dy


N = 300
y0 = torch.ones(N)
y0[0] = y0[-1] = 0.0
tspan = (0, 0.05)
t = torch.linspace(*tspan, 10)

# torchdiffeq (dopri5 by default)
dy = torch.zeros_like(y0)
start = time.time()
sol = odeint(lambda t, y: func(t, y, dy), y0, t)
print(f"torchdiffeq: {time.time()-start:.2f} s")

# SciPy (RK45)
start = time.time()
dy = np.zeros_like(y0)
odefunc = lambda t, y: func(t, y, dy)
sol2 = solve_ivp(odefunc, tspan, y0.numpy(), t_eval=t, method="RK45")
print(f"solve_ivp: {time.time()-start:.2f} s")
```
```text
GPU available: True
torchdiffeq: 8.37 s
solve_ivp: 1.16 s
```
Not remotely an expert, but maybe the advantage only shows up when you compare backpropagation speed and gradient error between the two methods. It also looks like you're not taking advantage of the adjoint method (`odeint_adjoint`) for solving the ODE.