
[Question] Neural CDE for regression

Open suargi opened this issue 1 year ago • 3 comments

Description

I would like to create a neural CDE for regression. To do so, I have taken the neural CDE example for classification and adapted it using the content of the neural ODE example for regression.

I am encountering some issues that I do not know how to solve. I would appreciate it if someone could point me in the right direction. Thank you!

Code

import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import matplotlib
import matplotlib.pyplot as plt
import optax

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(in_size=data_size, out_size=data_size, width_size=width_size, depth=depth, key=ikey)
        self.func = Func(data_size, width_size, depth, key=fkey)
        
    def __call__(self, ts, coeffs, evolving_out=False):
        # Each sample of data consists of some timestamps `ts`, and some `coeffs`
        # parameterising a control path. These are used to produce a continuous-time
        # input path `control`.
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = ts[1] - ts[0]
        y0 = self.initial(control.evaluate(ts[0]))
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

# ============================================================================
def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys


def get_data(dataset_size, *, key):
    length = 100
    ts = jnp.linspace(0, 10, length)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    ts_broadcasted = jnp.broadcast_to(ts, (dataset_size, length))
    ys = jnp.concatenate([ts_broadcasted[:, :, None], ys], axis=-1) # time is a channel
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts_broadcasted, ys)
    return ts_broadcasted, ys, coeffs

# ============================================================================

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

def main(
    dataset_size=256,
    batch_size=32,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(0.1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys, coeffs = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralCDE(data_size, width_size, depth, key=model_key)

    # Training loop like normal.
    #
    # Only thing to notice is that up until step 500 we train on only the first 10% of
    # each time series. This is a standard trick to avoid getting caught in a local
    # minimum.

    @eqx.filter_jit # value_and_grad
    def loss(model, ti, yi, coeff_i):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
        # MSE without time column
        return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(data_i, model, opt_state):
        ti, yi, *coeff_i = data_i
        loss, grads = grad_loss(model, ti, yi, coeff_i)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adam(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[:, : int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        _coeffs = tuple(arr[:, :int(length_size * length) - 1] for arr in coeffs)
        for step, data_i in zip(
            range(steps), dataloader((_ts, _ys) + _coeffs, batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(data_i, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        plt.plot(ts, ys[0, :, 0], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 1], c="dodgerblue")
        sample_coeffs = tuple(c[-1] for c in coeffs)
        pred = model(ts, sample_coeffs, evolving_out=True)
        plt.plot(ts, pred[:, 0], c="crimson", label="Model")
        plt.plot(ts, pred[:, 0], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, model


ts, ys, model = main()

Error

The error originates at line

return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)

The error message is too large to include here. To reproduce the error, please run the code above. Note: my intention is to compute the MSE between the predicted values and the true values. The variables yi and y_pred contain the timestamps in their first column, so for the MSE I only use the last two columns.
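
For clarity, here is a minimal, self-contained sketch of that intended computation (the array shapes below are purely illustrative, not taken from the code above):

import jax.numpy as jnp

# Illustrative shapes only: (batch, time, channels), where channel 0 holds the timestamps.
yi = jnp.ones((32, 100, 3))
y_pred = jnp.zeros((32, 100, 3))

# MSE over the two data channels, dropping the time channel.
mse = jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)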

Specifications

jax 0.4.35
jaxlib 0.4.35
jaxtyping 0.2.34
diffrax 0.6.0
equinox 0.11.8
numpy 2.1.2
optax 0.2.3

suargi avatar Nov 04 '24 14:11 suargi

You have grad_loss = eqx.filter_value_and_grad(loss, has_aux=True), but your loss function doesn't actually return any auxiliary variables. Setting has_aux to False yields:

Step: 0, Loss: 0.17582178115844727, Computation time: 13.726179122924805
Step: 100, Loss: 0.012010098434984684, Computation time: 0.04791116714477539
Step: 200, Loss: 0.01128536369651556, Computation time: 0.06833648681640625
Step: 300, Loss: 0.006681683007627726, Computation time: 0.03933405876159668
Step: 400, Loss: 0.008453472517430782, Computation time: 0.034162044525146484
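
For reference, a minimal sketch (using a toy eqx.nn.Linear model, not the code above) of the two consistent ways to use eqx.filter_value_and_grad: either keep has_aux=False and return only the scalar loss, or keep has_aux=True and return a (loss, aux) pair.

import equinox as eqx
import jax.numpy as jnp
import jax.random as jr

# Toy model and data, just to demonstrate the has_aux contract.
model = eqx.nn.Linear(2, 2, key=jr.PRNGKey(0))
x = jnp.ones((2,))
y = jnp.zeros((2,))

# Option 1: has_aux=False (the default) -- the loss function returns a scalar.
def loss_scalar(model, x, y):
    return jnp.mean((model(x) - y) ** 2)

value, grads = eqx.filter_value_and_grad(loss_scalar)(model, x, y)

# Option 2: has_aux=True -- the loss function must return a (scalar, aux) pair.
def loss_with_aux(model, x, y):
    pred = model(x)
    return jnp.mean((pred - y) ** 2), pred

(value, pred), grads = eqx.filter_value_and_grad(loss_with_aux, has_aux=True)(model, x, y)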

lockwo avatar Nov 04 '24 15:11 lockwo

Thank you, that solved the issue.

I have tried different hyperparameter combinations (number of steps, learning rate, number of layers, etc.), but I cannot get results as accurate as with the neural ODE. I am wondering if there is some problem with my code. Would it be possible for you to take a look and verify that my implementation is correct? Thank you.

Updated code:

import time

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import matplotlib
import matplotlib.pyplot as plt
import optax

class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.tanh,
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y)

class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(in_size=data_size, out_size=data_size, width_size=width_size, depth=depth, key=ikey)
        self.func = Func(data_size, width_size, depth, key=fkey)
        
    def __call__(self, ts, coeffs, evolving_out=False):
        # Each sample of data consists of some timestamps `ts`, and some `coeffs`
        # parameterising a control path. These are used to produce a continuous-time
        # input path `control`.
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = ts[1] - ts[0]
        y0 = self.initial(control.evaluate(ts[0]))
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=diffrax.SaveAt(ts=ts),
        )
        return solution.ys

# ============================================================================
def _get_data(ts, *, key):
    y0 = jr.uniform(key, (2,), minval=-0.6, maxval=1)

    def f(t, y, args):
        x = y / (1 + y)
        return jnp.stack([x[1], -x[0]], axis=-1)

    solver = diffrax.Tsit5()
    dt0 = 0.1
    saveat = diffrax.SaveAt(ts=ts)
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(f), solver, ts[0], ts[-1], dt0, y0, saveat=saveat
    )
    ys = sol.ys
    return ys


def get_data(dataset_size, *, key):
    length = 100
    ts = jnp.linspace(0, 10, length)
    key = jr.split(key, dataset_size)
    ys = jax.vmap(lambda key: _get_data(ts, key=key))(key)
    ts_broadcasted = jnp.broadcast_to(ts, (dataset_size, length))
    ys = jnp.concatenate([ts_broadcasted[:, :, None], ys], axis=-1) # time is a channel
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts_broadcasted, ys)
    return ts_broadcasted, ys, coeffs

# ============================================================================

def dataloader(arrays, batch_size, *, key):
    dataset_size = arrays[0].shape[0]
    assert all(array.shape[0] == dataset_size for array in arrays)
    indices = jnp.arange(dataset_size)
    while True:
        perm = jr.permutation(key, indices)
        (key,) = jr.split(key, 1)
        start = 0
        end = batch_size
        while end < dataset_size:
            batch_perm = perm[start:end]
            yield tuple(array[batch_perm] for array in arrays)
            start = end
            end = start + batch_size

def main(
    dataset_size=256,
    batch_size=64,
    lr_strategy=(3e-3, 3e-3),
    steps_strategy=(500, 500),
    length_strategy=(1, 1),
    width_size=64,
    depth=2,
    seed=5678,
    plot=True,
    print_every=100,
):
    key = jr.PRNGKey(seed)
    data_key, model_key, loader_key = jr.split(key, 3)

    ts, ys, coeffs = get_data(dataset_size, key=data_key)
    _, length_size, data_size = ys.shape

    model = NeuralCDE(data_size, width_size, depth, key=model_key)

    # Training loop like normal.
    #
    # (Unlike the first version, length_strategy is (1, 1) here, so both training
    # phases use the full time series rather than only the first 10% at the start.)

    @eqx.filter_jit # value_and_grad
    def loss(model, ti, yi, coeff_i):
        y_pred = jax.vmap(model, in_axes=(None, 0))(ti[0, :], coeff_i)
        # MSE without time column
        return jnp.mean((yi[:, :, 1:] - y_pred[:, :, 1:]) ** 2)

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=False)

    @eqx.filter_jit
    def make_step(data_i, model, opt_state):
        ti, yi, *coeff_i = data_i
        loss, grads = grad_loss(model, ti, yi, coeff_i)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return loss, model, opt_state

    for lr, steps, length in zip(lr_strategy, steps_strategy, length_strategy):
        optim = optax.adam(lr)
        opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
        _ts = ts[:, : int(length_size * length)]
        _ys = ys[:, : int(length_size * length)]
        _coeffs = tuple(arr[:, :int(length_size * length) - 1] for arr in coeffs)
        for step, data_i in zip(
            range(steps), dataloader((_ts, _ys) + _coeffs, batch_size, key=loader_key)
        ):
            start = time.time()
            loss, model, opt_state = make_step(data_i, model, opt_state)
            end = time.time()
            if (step % print_every) == 0 or step == steps - 1:
                print(f"Step: {step}, Loss: {loss}, Computation time: {end - start}")

    if plot:
        ts = ts[0, :]
        plt.plot(ts, ys[0, :, 1], c="dodgerblue", label="Real")
        plt.plot(ts, ys[0, :, 2], c="dodgerblue")
        sample_coeffs = tuple(c[-1] for c in coeffs)
        pred = model(ts, sample_coeffs, evolving_out=True)
        plt.plot(ts, pred[:, 1], c="crimson", label="Model")
        plt.plot(ts, pred[:, 2], c="crimson")
        plt.legend()
        plt.tight_layout()
        plt.savefig("neural_ode.png")
        plt.show()

    return ts, ys, coeffs, model


ts, ys, coeffs, model = main()

suargi avatar Nov 04 '24 15:11 suargi

I'm probably not familiar enough with Neural CDEs to be able to diagnose issues without substantial investigation. I would recommend checking piece by piece to make sure each of the subroutines is operating as expected, e.g. by comparing to specific known solutions on small problems.
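
For example, one such piece-by-piece check (a sketch on made-up data, not taken from the code above) is to confirm that the cubic control path built from backward_hermite_coefficients actually passes through the data at the observation times:

import diffrax
import jax
import jax.numpy as jnp

# Made-up smooth data with time as the first channel, mirroring the data layout above.
ts = jnp.linspace(0.0, 10.0, 100)
ys = jnp.stack([ts, jnp.sin(ts), jnp.cos(ts)], axis=-1)

coeffs = diffrax.backward_hermite_coefficients(ts, ys)
control = diffrax.CubicInterpolation(ts, coeffs)

# The interpolant should reproduce the data at every knot (up to floating point error).
recovered = jax.vmap(control.evaluate)(ts)
assert jnp.allclose(recovered, ys, atol=1e-4)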

lockwo avatar Nov 04 '24 15:11 lockwo