Owen L

338 comments by Owen L

> I think this interaction is highlighting that in JAX, using init=False has maybe zero use cases, and is just a footgun. We could add a warning for this? WDYT?...

Relevant previous discussion: https://github.com/patrick-kidger/equinox/issues/949

> I'd be happy to have a cross-link if you think it's worth it! :)

Sure, I will open one.

> (Whilst we're here, I'm also conscious that our current rope...

Not exactly the same (this one is a checkpointed while loop), but it shows the same behavior, so I figure it's probably the same underlying issue: https://github.com/jax-ml/jax/issues/31282

It does seem to be fixed by not using the jnp array in the shape structure, but in a fresh Colab notebook with diffrax pip-installed (with jax 0.7.2) it does...
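If "shape structure" here means a `jax.ShapeDtypeStruct` (an assumption; the truncated comment doesn't say which argument was involved), the workaround amounts to describing the state with static shape/dtype metadata only, rather than handing over a concrete `jnp` array. A minimal sketch, where `y0` is a hypothetical initial state:

```python
import jax
import jax.numpy as jnp

# Hypothetical initial state; stands in for whatever array was being passed.
y0 = jnp.zeros(3)

# Option 1: build the structure explicitly from static shape/dtype metadata.
struct = jax.ShapeDtypeStruct(y0.shape, y0.dtype)

# Option 2: let JAX derive it. eval_shape traces the function abstractly and
# returns ShapeDtypeStructs instead of concrete arrays.
derived = jax.eval_shape(lambda: y0)

print(struct.shape, struct.dtype)
print(derived.shape, derived.dtype)
```

A `ShapeDtypeStruct` carries no values, only shape and dtype, so nothing array-valued ends up being treated as part of the structure.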

Most of my examples were CPU-focused (although there was some GPU work in https://github.com/jax-ml/jax/issues/20968). What usually solved my issues was simply disabling the new CPU thunk runtime for XLA,...
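One way to disable the thunk runtime is through the `XLA_FLAGS` environment variable with the XLA flag `--xla_cpu_use_thunk_runtime=false`. This is a sketch assuming a JAX/XLA version that still accepts that flag; later XLA releases may have dropped it. The flag must be in place before JAX initializes its backend, so setting it before `import jax` (or exporting it in the shell) is the safe choice. The helper `disable_cpu_thunk_runtime` is hypothetical, not part of JAX.

```python
import os


def disable_cpu_thunk_runtime() -> str:
    """Hypothetical helper: append the flag to XLA_FLAGS without clobbering
    any flags that are already set, and return the resulting value."""
    flag = "--xla_cpu_use_thunk_runtime=false"
    existing = os.environ.get("XLA_FLAGS", "")
    if flag not in existing.split():
        os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()
    return os.environ["XLA_FLAGS"]


# Call this before `import jax`; XLA reads XLA_FLAGS when the backend starts.
disable_cpu_thunk_runtime()
```

The shell equivalent is `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false python script.py`, which avoids any ordering concerns inside the script.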

You have `grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)`, but you don't actually return any auxiliary variables. Setting `has_aux=False` yields:

```
Step: 0, Loss: 0.17582178115844727, Computation time: 13.726179122924805
Step: 100, Loss:...
```
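The `has_aux` convention is the same one `jax.value_and_grad` uses, and `eqx.filter_value_and_grad` follows it: with `has_aux=True` the function must return a `(loss, aux)` pair, and the call returns `((loss, aux), grads)`. A minimal sketch with plain `jax.value_and_grad`; the toy linear model and the names `loss` and `loss_with_aux` are hypothetical:

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Returns only a scalar, so use it with has_aux=False (the default).
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


def loss_with_aux(w, x, y):
    # Returns (scalar loss, auxiliary pytree), which is what has_aux=True expects.
    pred = x @ w
    return jnp.mean((pred - y) ** 2), {"pred": pred}


# Toy data for a hypothetical linear model.
x = jnp.ones((4, 3))
y = jnp.zeros(4)
w = jnp.full(3, 0.5)

# No aux: the result is (value, grads).
value, grads = jax.value_and_grad(loss)(w, x, y)

# With aux: the result is ((value, aux), grads).
(value_aux, aux), grads_aux = jax.value_and_grad(loss_with_aux, has_aux=True)(w, x, y)
```

Both calls give the same loss and gradients; the only difference is whether an extra pytree travels alongside the loss.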

I'm probably not familiar enough with Neural CDEs to be able to diagnose issues without substantial investigation. I would recommend checking piece by piece to make sure each of the...

Yea, that's a good point. It has less effect on the `batch` mode of the PR (since the weighting is all based on the current batch), but will still impact...

The code seemed to work fine with a fresh install on Colab. Are there more details about your setup that might be relevant?