[Question] Neural CDE for regression
Description
I would like to create a neural CDE for regression. To do so, I have taken the neural CDE example for classification and adapted it using the neural ODE example for regression.
I am encountering some issues that I do not know how to solve. I would appreciate it if someone could point me in the right direction. Thank you!
Code
import time
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import matplotlib
import matplotlib.pyplot as plt
import optax
class Func(eqx.Module):
mlp: eqx.nn.MLP
def __init__(self, data_size, width_size, depth, *, key, **kwargs):
super().__init__(**kwargs)
self.mlp = eqx.nn.MLP(
in_size=data_size,
out_size=data_size,
width_size=width_size,
depth=depth,
activation=jnn.tanh,
final_activation=jnn.tanh,
key=key,
)
def __call__(self, t, y, args):
return self.mlp(y)
class NeuralCDE(eqx.Module):
initial: eqx.nn.MLP
func: Func
def __init__(self, data_size, width_size, depth, *, key, **kwargs):
super().__init__(**kwargs)
ikey, fkey, lkey = jr.split(key, 3)
self.initial = eqx.nn.MLP(in_size=data_size, out_size=data_size, width_size=width_size, depth=depth, key=ikey)
self.func = Func(data_size, width_size, depth, key=fkey)
def __call__(self, ts, coeffs, evolving_out=False):
# Each sample of data consists of some timestamps `ts`, and some `coeffs`
# parameterising a control path. These are used to produce a continuous-time
# input path `control`.
control = diffrax.CubicInterpolation(ts, coeffs)
term = diffrax.ControlTerm(self.func, control).to_ode()
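        # The control path is differentiable, so .to_ode() rewrites the CDE
        # dy = f(y) dX(t) as the equivalent ODE dy/dt = f(y) dX/dt, which a
        # standard ODE solver can then handle.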
solver = diffrax.Tsit5()
dt0 = ts[1] - ts[0]
y0 = self.initial(control.evaluate(ts[0]))
solution = diffrax.diffeqsolve(
term,
solver,
ts[0],
ts[-1],
dt0,
y0,
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
saveat=diffrax.SaveAt(ts=ts),
)
return solution.ys
# ============================================================================
def _get_data(ts, *, key):
y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)
def f(t, y, args):
x = y / (1 + y)
return jnp.stack([x[1], -x[0]], axis=-1)
solver = diffrax.Tsit5()
dt0 = 0.1
saveat = diffrax.SaveAt(ts=ts)
sol = diffrax.diffeqsolve(
diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
)
ys = sol.ys
return ys
def get_data(dataset_size, *, key):
length = 100
ts = jnp.linspace(0, 10, length)
key = jr.split(key, dataset_size)
ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
ts_broadcasted = jnp.broadcast_to(ts, (dataset_size, length))
ys = jnp.concatenate([ts_broadcasted[:, :, None], ys], axis=-1) # time is a channel
coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts_broadcasted, ys)
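    # Precompute cubic Hermite (backward differences) interpolation coefficients for
    # each sample; these parameterise the continuous control path passed to the CDE.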
return ts_broadcasted, ys, coeffs
# ============================================================================
def dataloader(arrays, batch_size, *, key):
dataset_size = arrays[0].shape[0]
assert all(array.shape[0] == dataset_size for array in arrays)
indices = jnp.arange(dataset_size)
while True:
perm = jr.permutation(key, indices)
(key,) = jr.split(key, 1)
start = 0
end = batch_size
while end < dataset_size:
batch_perm = perm[start:end]
yield tuple(array[batch_perm] for array in arrays)
start = end
end = start + batch_size
def main(
dataset_size=256,
batch_size=32,
lr_strategy=(3e-3, 3e-3),
steps_strategy=(500, 500),
length_strategy=(0.1, 1),
width_size=64,
depth=2,
seed=5678,
plot=True,
print_every=100,
):
key = jr.PRNGKey(seed)
data_key, model_key, loader_key = jr.split(key, 3)
ts, ys, coeffs = get_data(dataset_size, key=data_key)
_, length_size, data_size = ys.shape
model = NeuralCDE(data_size, width_size, depth, key=model_key)
# Training loop like normal.
#
# Only thing to notice is that up until step 500 we train on only the first 10% of
# each time series. This is a standard trick to avoid getting caught in a local
# minimum.
@eqx.filter_jit # value_and_grad
def loss(model, ti, yi, coeff_i):
y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
# MSE without time column
return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)
grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)
@eqx.filter_jit
def make_step(data_i, model, opt_state):
ti, yi, *coeff_i = data_i
loss, grads = grad_loss(model, ti, yi, coeff_i)
updates, opt_state = optim.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return loss, model, opt_state
for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
optim = optax.adam(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
_ts = ts[:, : int(length_size * length)]
_ys = ys[:, : int(length_size * length)]
_coeffs = tuple(arr[:, :int(length_size * length) - 1] for arr in coeffs)
for step, data_i in zip(
range(steps), dataloader((_ts, _ys) + _coeffs, batch_size, key=loader_key)
):
start = time.time()
loss, model, opt_state = make_step(data_i, model, opt_state)
end = time.time()
if (step % print_every) == 0 or step == steps - 1:
print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")
if plot:
plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
plt.plot(ts, ys[0, :, 1], c="dodgerblue")
sample_coeffs = tuple(c[-1] for c in coeffs)
pred = model(ts, sample_coeffs, evolving_out=True)
plt.plot(ts, pred[:, 0], c="crimson", label="Model")
plt.plot(ts, pred[:, 0], c="crimson")
plt.legend()
plt.tight_layout()
plt.savefig("neural_ode.png")
plt.show()
return ts, ys, model
ts, ys, model = main()
Error
The error originates at the line
return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)
The error message is too large to write down here. To reproduce the error, please run the code above.
Note: my intention is to compute the MSE between the predicted values and the true values. The variables yi and y_pred contain the time values in the first column and the time series values in the last two columns. Therefore, for the MSE I only use the last two columns.
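For concreteness, this is the slicing I mean (shapes as produced by get_data above):

# yi and y_pred have shape (batch, length, 3); column 0 is time, columns 1 and 2 hold the state.
mse = jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)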
Specifications
jax 0.4.35
jaxlib 0.4.35
jaxtyping 0.2.34
diffrax 0.6.0
equinox 0.11.8
numpy 2.1.2
optax 0.2.3
You have grad_loss = eqx.filter_value_and_grad(loss, has_aux=True) but you don't actually return any auxiliary values from loss. Setting that to False yields:
Step: 0, Loss: 0.17582178115844727, Computation time: 13.726179122924805
Step: 100, Loss: 0.012010098434984684, Computation time: 0.04791116714477539
Step: 200, Loss: 0.01128536369651556, Computation time: 0.06833648681640625
Step: 300, Loss: 0.006681683007627726, Computation time: 0.03933405876159668
Step: 400, Loss: 0.008453472517430782, Computation time: 0.034162044525146484
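For reference, a minimal sketch of the two consistent patterns, reusing model, ti, yi, coeff_i from your code above:

# Option 1: the loss returns only a scalar, so has_aux stays False (the default).
def loss(model, ti, yi, coeff_i):
    y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
    return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)

grad_loss = eqx.filter_value_and_grad(loss, has_aux=False)
value, grads = grad_loss(model, ti, yi, coeff_i)

# Option 2: the loss returns a (scalar, aux) pair, and only then is has_aux=True correct.
def loss_with_aux(model, ti, yi, coeff_i):
    y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
    mse = jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)
    return mse, y_pred  # gradients are taken with respect to mse only; y_pred is passed through

grad_loss_aux = eqx.filter_value_and_grad(loss_with_aux, has_aux=True)
(value, y_pred), grads = grad_loss_aux(model, ti, yi, coeff_i)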
Thank you, that solved the issue.
I have tried different hyperparameter combinations (number of epochs, learning rate, number of layers, etc.), but I cannot get results as accurate as with the neural ODE. I am wondering if there is some problem with my code. Would it be possible for you to take a look and verify that my implementation is correct? Thank you.
Updated code:
import time
import diffrax
import equinox as eqx # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import matplotlib
import matplotlib.pyplot as plt
import optax
class Func(eqx.Module):
mlp: eqx.nn.MLP
def __init__(self, data_size, width_size, depth, *, key, **kwargs):
super().__init__(**kwargs)
self.mlp = eqx.nn.MLP(
in_size=data_size,
out_size=data_size,
width_size=width_size,
depth=depth,
activation=jnn.tanh,
final_activation=jnn.tanh,
key=key,
)
def __call__(self, t, y, args):
return self.mlp(y)
class NeuralCDE(eqx.Module):
initial: eqx.nn.MLP
func: Func
def __init__(self, data_size, width_size, depth, *, key, **kwargs):
super().__init__(**kwargs)
ikey, fkey, lkey = jr.split(key, 3)
self.initial = eqx.nn.MLP(in_size=data_size, out_size=data_size, width_size=width_size, depth=depth, key=ikey)
self.func = Func(data_size, width_size, depth, key=fkey)
def __call__(self, ts, coeffs, evolving_out=False):
# Each sample of data consists of some timestamps `ts`, and some `coeffs`
# parameterising a control path. These are used to produce a continuous-time
# input path `control`.
control = diffrax.CubicInterpolation(ts, coeffs)
term = diffrax.ControlTerm(self.func, control).to_ode()
solver = diffrax.Tsit5()
dt0 = ts[1] - ts[0]
y0 = self.initial(control.evaluate(ts[0]))
solution = diffrax.diffeqsolve(
term,
solver,
ts[0],
ts[-1],
dt0,
y0,
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
saveat=diffrax.SaveAt(ts=ts),
)
return solution.ys
# ============================================================================
def _get_data(ts, *, key):
y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)
def f(t, y, args):
x = y / (1 + y)
return jnp.stack([x[1], -x[0]], axis=-1)
solver = diffrax.Tsit5()
dt0 = 0.1
saveat = diffrax.SaveAt(ts=ts)
sol = diffrax.diffeqsolve(
diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
)
ys = sol.ys
return ys
def get_data(dataset_size, *, key):
length = 100
ts = jnp.linspace(0, 10, length)
key = jr.split(key, dataset_size)
ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
ts_broadcasted = jnp.broadcast_to(ts, (dataset_size, length))
ys = jnp.concatenate([ts_broadcasted[:, :, None], ys], axis=-1) # time is a channel
coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts_broadcasted, ys)
return ts_broadcasted, ys, coeffs
# ============================================================================
def dataloader(arrays, batch_size, *, key):
dataset_size = arrays[0].shape[0]
assert all(array.shape[0] == dataset_size for array in arrays)
indices = jnp.arange(dataset_size)
while True:
perm = jr.permutation(key, indices)
(key,) = jr.split(key, 1)
start = 0
end = batch_size
while end < dataset_size:
batch_perm = perm[start:end]
yield tuple(array[batch_perm] for array in arrays)
start = end
end = start + batch_size
def main(
dataset_size=256,
batch_size=64,
lr_strategy=(3e-3, 3e-3),
steps_strategy=(500, 500),
length_strategy=(1, 1),
width_size=64,
depth=2,
seed=5678,
plot=True,
print_every=100,
):
key = jr.PRNGKey(seed)
data_key, model_key, loader_key = jr.split(key, 3)
ts, ys, coeffs = get_data(dataset_size, key=data_key)
_, length_size, data_size = ys.shape
model = NeuralCDE(data_size, width_size, depth, key=model_key)
# Training loop like normal.
#
    # Note: unlike the first version, length_strategy=(1, 1) below means both phases
    # now train on the full length of each time series (the "train on the first 10%
    # only" trick from the original example is no longer used).
@eqx.filter_jit # value_and_grad
def loss(model, ti, yi, coeff_i):
y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
# MSE without time column
return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)
grad_loss = eqx.filter_value_and_grad(loss, has_aux=False)
@eqx.filter_jit
def make_step(data_i, model, opt_state):
ti, yi, *coeff_i = data_i
loss, grads = grad_loss(model, ti, yi, coeff_i)
updates, opt_state = optim.update(grads, opt_state)
model = eqx.apply_updates(model, updates)
return loss, model, opt_state
for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
optim = optax.adam(lr)
opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
_ts = ts[:, : int(length_size * length)]
_ys = ys[:, : int(length_size * length)]
_coeffs = tuple(arr[:, :int(length_size * length) - 1] for arr in coeffs)
for step, data_i in zip(
range(steps), dataloader((_ts, _ys) + _coeffs, batch_size, key=loader_key)
):
start = time.time()
loss, model, opt_state = make_step(data_i, model, opt_state)
end = time.time()
if (step % print_every) == 0 or step == steps - 1:
print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")
if plot:
ts = ts[0, :]
plt.plot(ts, ys[0, :, 1], c="dodgerblue", label="Real")
plt.plot(ts, ys[0, :, 2], c="dodgerblue")
sample_coeffs = tuple(c[-1] for c in coeffs)
pred = model(ts, sample_coeffs, evolving_out=True)
plt.plot(ts, pred[:, 1], c="crimson", label="Model")
plt.plot(ts, pred[:, 2], c="crimson")
plt.legend()
plt.tight_layout()
plt.savefig("neural_ode.png")
plt.show()
return ts, ys, coeffs, model
ts, ys, coeffs, model = main()
I'm probably not familiar enough with Neural CDEs to be able to diagnose issues without substantial investigation. I would recommend checking piece by piece to make sure each of the subroutines is operating as expected, e.g. by comparing to specific known solutions on small problems.
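As a concrete example of that kind of piecewise check, something along these lines (just a sketch, assuming the get_data and NeuralCDE definitions above) would confirm that the interpolation and an untrained model behave sensibly, independently of training:

import jax
import jax.numpy as jnp
import jax.random as jr
import diffrax

# Check 1: the cubic interpolation of the control path should pass (up to floating
# point error) through the observed data at the observation times.
ts, ys, coeffs = get_data(dataset_size=4, key=jr.PRNGKey(0))
sample_coeffs = tuple(c[0] for c in coeffs)
control = diffrax.CubicInterpolation(ts[0], sample_coeffs)
recon = jax.vmap(control.evaluate)(ts[0])
print(jnp.max(jnp.abs(recon - ys[0])))  # expect something at the level of float error

# Check 2: an untrained model should run end-to-end and return an output whose
# shape matches the data, i.e. (length, data_size).
model = NeuralCDE(ys.shape[-1], width_size=8, depth=1, key=jr.PRNGKey(1))
pred = model(ts[0], sample_coeffs)
print(pred.shape, ys[0].shape)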