equinox
Slow JIT compilation when doing teacher forcing
Hello Patrick,
I recently moved from PyTorch to JAX/Equinox (thanks a lot for your contributions to both Diffrax and Equinox) and have been porting all my PyTorch code over.
I am having trouble implementing teacher forcing for my neural ODE. I used it in PyTorch because I noticed it helped keep the network from diverging from the ground-truth trajectory.
The issue is that my code takes a long time to compile, so training is much slower than with PyTorch. If I understand the comments you have made on similar issues correctly, this is probably due to the Python for loop.
I have tried to use jax.lax.scan, but I can't find a proper way to combine it with teacher forcing.
Here is my code; you can assume my model is just a neural ODE (NODE) with an MLP vector field:
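To illustrate why a Python loop can slow compilation: JAX runs a Python `for` loop while tracing, so every iteration is recorded separately in the program that gets compiled, whereas `jax.lax.scan` traces its body only once. A minimal sketch comparing the two on a toy function (`python_loop` and `scan_loop` are illustrative names, not from the original code):

```python
import jax
import jax.numpy as jnp


def python_loop(x):
    # The Python loop runs at trace time: each iteration adds its own
    # operations to the traced program.
    for _ in range(10):
        x = jnp.sin(x)
    return x


def scan_loop(x):
    # lax.scan traces `body` once and records a single `scan` primitive,
    # regardless of the number of iterations.
    def body(carry, _):
        return jnp.sin(carry), None

    x, _ = jax.lax.scan(body, x, None, length=10)
    return x


x = jnp.ones(3)
loop_ops = [eqn.primitive.name for eqn in jax.make_jaxpr(python_loop)(x).jaxpr.eqns]
scan_ops = [eqn.primitive.name for eqn in jax.make_jaxpr(scan_loop)(x).jaxpr.eqns]
print(len(loop_ops), scan_ops)
```

With ODE solves in the loop body instead of `jnp.sin`, each unrolled iteration carries a whole solver, which is where compile time tends to grow.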
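One way teacher forcing might be phrased as a `jax.lax.scan` is to step between consecutive time points and, at each step, swap the carried state for the ground truth wherever a boolean mask says so. The sketch below assumes a fixed explicit Euler step as a stand-in for an ODE solve over `[t0, t1]` (in practice this would presumably be a `diffrax.diffeqsolve` call with the MLP vector field); `euler_step`, `teacher_forcing_scan` and `decay` are hypothetical names:

```python
import jax
import jax.numpy as jnp


def euler_step(vector_field, y, t0, t1):
    # One explicit Euler step; stands in for an ODE solve over [t0, t1].
    return y + (t1 - t0) * vector_field(t0, y)


def teacher_forcing_scan(vector_field, ts, ys, mask):
    """Roll out a trajectory, restarting from the ground truth where mask is True.

    ts:   (T,) time points
    ys:   (T, D) ground-truth states
    mask: (T - 1,) bool; mask[i] resets the state to ys[i] before stepping i -> i+1
    """

    def step(y_prev, inputs):
        t0, t1, y_true, reset = inputs
        # Teacher forcing: use the ground-truth state where requested.
        y0 = jnp.where(reset, y_true, y_prev)
        y1 = euler_step(vector_field, y0, t0, t1)
        return y1, y1

    inputs = (ts[:-1], ts[1:], ys[:-1], mask)
    _, preds = jax.lax.scan(step, ys[0], inputs)
    # Prepend the initial condition so the output lines up with ys.
    return jnp.concatenate([ys[:1], preds], axis=0)


def decay(t, y):
    # Toy vector field dy/dt = -y, standing in for the MLP.
    return -y


key = jax.random.PRNGKey(0)
ts = jnp.linspace(0.0, 1.0, 6)
ys = jnp.exp(-ts)[:, None] * jnp.ones((1, 2))
mask = jax.random.uniform(key, (ts.shape[0] - 1,)) < 0.5
preds = teacher_forcing_scan(decay, ts, ys, mask)
```

Because the mask is an array rather than Python control flow, the traced program has the same size for every restart pattern. For a batch, one option is to close over the vector field, e.g. `jax.vmap(lambda t, y: teacher_forcing_scan(decay, t, y, mask))(ts_batch, ys_batch)`.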
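```python
import jax
import jax.numpy as jnp
import numpy as np
import equinox as eqx


@eqx.filter_value_and_grad
def compute_loss(diff_model, static_model, batch, epsilon, min_op, lambda_0):
    # Recombine the trainable and static parts of the model.
    model = eqx.combine(diff_model, static_model)
    outputs = teacher_forcing(model, batch, epsilon)
    y = batch['states']
    mse_loss = jnp.mean((outputs - y) ** 2)
    return mse_loss
```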
```python
def teacher_forcing(model, batch, epsilon):
    t = batch['t']
    y = batch['states']
    env = batch['env']
    if epsilon < 1e-3:
        epsilon = 0
    if epsilon == 0:
        # No teacher forcing: integrate the whole trajectory in one go.
        res = jax.vmap(model, in_axes=(0, 0, 0, None))(y, t, env, 0)
    else:
        # Randomly choose the time indices at which to restart from the ground truth.
        eval_points = np.random.random(len(t[0])) < epsilon
        eval_points[-1] = False
        eval_points = eval_points[1:]
        start_i, end_i = 0, None
        res = []
        for i, eval_point in enumerate(eval_points):
            if eval_point:
                end_i = i + 1
                # Integrate over t[start_i : end_i + 1], restarting at index start_i.
                t_seg = t[:, start_i:end_i + 1]
                res_seg = jax.vmap(model, in_axes=(0, 0, 0, None))(y, t_seg, env, start_i)
                if len(res) == 0:
                    res.append(res_seg)
                else:
                    # Drop the first point, which duplicates the end of the previous segment.
                    res.append(res_seg[:, 1:, :])
                start_i = end_i
        # Final segment, from the last restart point to the end of the trajectory.
        t_seg = t[:, start_i:]
        res_seg = jax.vmap(model, in_axes=(0, 0, 0, None))(y, t_seg, env, start_i)
        if len(res) == 0:
            res.append(res_seg)
        else:
            res.append(res_seg[:, 1:, :])
        res = jnp.concatenate(res, axis=1)
    return jnp.moveaxis(res, 1, 2)
```
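A related point that may be worth checking in `teacher_forcing`: `np.random.random` is plain NumPy, so inside a jitted function it is evaluated only while tracing. Whenever `eqx.filter_jit` reuses a cached compilation (same `epsilon`, same shapes), the same restart pattern is replayed rather than redrawn. A minimal sketch contrasting this with `jax.random`, using plain `jax.jit`; `numpy_mask`, `jax_random_mask` and `NUM_STEPS` are illustrative names:

```python
import jax
import jax.numpy as jnp
import numpy as np

NUM_STEPS = 64  # illustrative trajectory length


@jax.jit
def numpy_mask(epsilon):
    # np.random runs once, during tracing; the sampled values are baked into
    # the compiled function as a constant and reused on every later call.
    return jnp.asarray(np.random.random(NUM_STEPS)) < epsilon


@jax.jit
def jax_random_mask(key, epsilon):
    # jax.random is part of the traced computation, so a fresh key gives
    # a fresh mask without recompiling.
    return jax.random.uniform(key, (NUM_STEPS,)) < epsilon


key = jax.random.PRNGKey(0)
for step in range(3):
    key, subkey = jax.random.split(key)
    # The NumPy-based count repeats; the jax.random-based count varies.
    print(int(numpy_mask(0.5).sum()), int(jax_random_mask(subkey, 0.5).sum()))
```

Threading a PRNG key through `train_step` (splitting it each step) is the usual JAX pattern for this kind of per-step randomness.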
```python
@eqx.filter_jit
def train_step(model, filter_spec, batch, epsilon, optim, opt_state, min_op, lambda_0):
    diff_model, static_model = eqx.partition(model, filter_spec)
    loss, grads = compute_loss(diff_model, static_model, batch, epsilon, min_op, lambda_0)
    updates, opt_state = optim.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return loss, model, opt_state
```
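Under `eqx.filter_jit`, only JAX and NumPy arrays are traced; other arguments are treated as static, so a Python float `epsilon` becomes part of the compilation cache key. If `epsilon` changes between steps (for example under a decay schedule, which is an assumption here), each new value would force a fresh trace and compile of the whole unrolled loop. Passing `jnp.asarray(epsilon)` makes it traced, but then Python comparisons such as `if epsilon < 1e-3:` no longer work and need to be expressed with `jnp.where` (or `jax.lax.cond`). The loop over `eval_points` likewise cannot branch on a traced mask, so this goes together with a loop-free or scan-based formulation. A minimal sketch with plain `jax.jit` and a hypothetical `effective_epsilon` helper that counts traces:

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def effective_epsilon(epsilon):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces.
    # A Python `if epsilon < 1e-3:` would fail on a traced value,
    # so the threshold is written with jnp.where instead.
    return jnp.where(epsilon < 1e-3, 0.0, epsilon)


for eps in [0.9, 0.5, 0.1, 1e-4]:
    print(float(effective_epsilon(jnp.asarray(eps))))
print("traces:", trace_count)
```

Because every call passes an array of the same shape and dtype, the function is traced once even though `epsilon` takes four different values.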
Any suggestions for improving efficiency would be greatly appreciated. Thanks a lot!