
Slow JIT compilation when doing teacher forcing

Open · itsakk opened this issue on Nov 25, 2023 · 1 comment

Hello Patrick,

I recently decided to move from torch to jax/equinox (thanks a lot for your contributions to both diffrax and equinox) and have been porting all my torch code over.

I am having trouble implementing teacher forcing for my Neural ODE. I used to do this in PyTorch, as I noticed it helps keep the network from diverging from the ground-truth trajectory.

The issue is that my code takes a very long time to compile, so training the network ends up much slower than with torch. If my understanding of your comments on similar issues is correct, this is probably due to the for loop.
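
To illustrate my understanding (a toy example, not my real code): under jit, a Python for loop is unrolled at trace time, so every iteration adds operations to the compiled graph, and anything that changes the number or shapes of the iterations forces a fresh compilation.

import jax
import jax.numpy as jnp

@jax.jit
def unrolled_sum(xs):
    # The Python loop is unrolled during tracing: each iteration adds ops to
    # the XLA graph, so compile time grows with len(xs), and changing len(xs)
    # triggers a brand-new compilation.
    out = 0.0
    for x in xs:
        out = out + jnp.sin(x)
    return out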

I have tried to use jax.lax.scan, but I can't find a proper way to apply it when doing teacher forcing.
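
For reference, this is roughly the shape I imagine a scan-based version would take (a minimal sketch with a hypothetical one-step step_fn, not my actual code; I don't see how to reconcile it with the variable-length segment solves in my code below):

import jax
import jax.numpy as jnp

def rollout_with_teacher_forcing(step_fn, y, t, reset_mask):
    # step_fn(y_prev, t_i, t_next) -> prediction at t_next (hypothetical).
    # y: ground-truth states (T, dim); t: times (T,); reset_mask: bools (T - 1,),
    # True where the trajectory should restart from the ground truth.
    def scan_step(y_prev, inp):
        t_i, t_next, y_true, reset = inp
        y_pred = step_fn(y_prev, t_i, t_next)
        # Carry either the prediction or the ground-truth state forward.
        y_next = jnp.where(reset, y_true, y_pred)
        return y_next, y_pred

    _, preds = jax.lax.scan(scan_step, y[0], (t[:-1], t[1:], y[1:], reset_mask))
    return jnp.concatenate([y[:1], preds], axis=0)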

Here is my code; you can assume the model is just a Neural ODE with an MLP vector field:

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np

@eqx.filter_value_and_grad
def compute_loss(diff_model, static_model, batch, epsilon, min_op, lambda_0):
    # Recombine the trainable and static parts of the model before calling it.
    model = eqx.combine(diff_model, static_model)
    outputs = teacher_forcing(model, batch, epsilon)
    y = batch['states']
    mse_loss = jnp.mean((outputs - y) ** 2)
    return mse_loss

def teacher_forcing(model, batch, epsilon):
    # Split the time grid into segments at randomly chosen reset points and
    # integrate each segment starting from the ground-truth state at its start
    # index; epsilon is the probability of resetting at any given time step.
    t = batch['t']
    y = batch['states']
    env = batch['env']

    if epsilon < 1e-3:
        epsilon = 0

    if epsilon == 0:
        # No teacher forcing: integrate the whole trajectory from the initial state.
        res = jax.vmap(model, in_axes=(0, 0, 0, None))(y, t, env, 0)
    else:
        # Draw the reset points with plain NumPy, so they are fixed at trace time.
        eval_points = np.random.random(len(t[0])) < epsilon
        eval_points[-1] = False
        eval_points = eval_points[1:]
        start_i, end_i = 0, None
        res = []
        for i, eval_point in enumerate(eval_points):
            if eval_point:
                # Integrate from the ground-truth state at start_i up to this reset point.
                end_i = i + 1
                t_seg = t[:, start_i:end_i + 1]
                res_seg = jax.vmap(model, in_axes=(0, 0, 0, None))(y, t_seg, env, start_i)
                if len(res) == 0:
                    res.append(res_seg)
                else:
                    # Drop the first point of the segment so it isn't duplicated.
                    res.append(res_seg[:, 1:, :])
                start_i = end_i
        # Final segment, from the last reset point to the end of the trajectory.
        t_seg = t[:, start_i:]
        res_seg = jax.vmap(model, in_axes=(0, 0, 0, None))(y, t_seg, env, start_i)
        if len(res) == 0:
            res.append(res_seg)
        else:
            res.append(res_seg[:, 1:, :])
        res = jnp.concatenate(res, axis=1)
    return jnp.moveaxis(res, 1, 2)

@eqx.filter_jit
def train_step(model, filter_spec, batch, epsilon, optim, opt_state, min_op, lambda_0):
    diff_model, static_model = eqx.partition(model, filter_spec)
    loss, grads = compute_loss(diff_model, static_model, batch, epsilon, min_op, lambda_0)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state

Any suggestions to increase efficiency are greatly appreciated. Thanks a lot!

itsakk · Nov 25 '23 15:11