Patrick Kidger
No. As above this is intended -- the values at which you save may not be the values at which you evaluate the vector field. So there's no way to...
Can you give the traceback and message for the error you obtain, and the versions of JAX and Equinox you are using? With JAX 0.4.34 and Equinox 0.11.8 I am...
At least right now this is intended. It's just tricky to handle edge-cases around mixed types within the same object (e.g. `vector=(f32[...], f64[...])`), or backprop, or complex numbers.
Welcome to JAX and Equinox! :) You can check recompilations using [`eqx.debug.assert_max_traces`](https://docs.kidger.site/equinox/api/debug/#equinox.debug.assert_max_traces). This will also tell you which argument caused recompilations. Note that recompilation is triggered by passing new static...
Personally I'm -1 on this. I don't think Marimo should seek to reinvent every orthogonal piece of tooling -- this has always been one of the greatest weaknesses of Jupyter...
Ah -- this must be a mypy vs pyright difference. Happy to take a PR on this :)
Fixed in #822 :)
From your description it seems possible that you've gotten two arguments the wrong way around, so that you are passing a `Linear` layer in as the input `x`. I suspect...
IIUC then the `updates, opt_state = optim.update(grads, opt_state, model_)` line is missing the `value`, `grad`, and `value_fn` arguments. :)
I'm not sure, I'm afraid :) This is a detail specific to Optax, not Equinox! The above is the limit of my knowledge of their library!