Patrick Kidger
At first glance, the problem is probably that you're doing `opt_state = optim.init(eqx.filter(model, eqx.is_array_like))` rather than `opt_state = optim.init(eqx.filter(model, eqx.is_array))`. Equinox generally wants you to treat non-arrays as static.
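To see why the filter matters, here is a minimal sketch that uses only `jax.tree_util`. The helpers `is_array`, `is_array_like` and `filter_tree` are hypothetical stand-ins that approximate `eqx.is_array`, `eqx.is_array_like` and `eqx.filter`; in real code you would just call the Equinox versions. The point is that `is_array_like` also keeps Python scalars such as a dropout rate or a layer count, so they end up in the pytree handed to `optim.init` even though Equinox expects them to be static.

```python
import jax
import jax.numpy as jnp
import numpy as np


def is_array(x):
    # Hypothetical approximation of eqx.is_array: only genuine arrays count.
    return isinstance(x, (jax.Array, np.ndarray, np.generic))


def is_array_like(x):
    # Hypothetical approximation of eqx.is_array_like: also accepts Python scalars.
    return is_array(x) or isinstance(x, (bool, int, float, complex))


def filter_tree(tree, predicate):
    # Like eqx.filter with its default replace=None: leaves failing the
    # predicate become None, which JAX treats as an empty subtree.
    return jax.tree_util.tree_map(lambda x: x if predicate(x) else None, tree)


# A toy "model": two arrays plus two non-array hyperparameters.
params = {
    "weight": jnp.ones((2, 2)),
    "bias": jnp.zeros(2),
    "dropout_rate": 0.1,
    "num_layers": 3,
}

array_leaves = jax.tree_util.tree_leaves(filter_tree(params, is_array))
array_like_leaves = jax.tree_util.tree_leaves(filter_tree(params, is_array_like))
print(len(array_leaves), len(array_like_leaves))
```

With `is_array`, only `weight` and `bias` survive; with `is_array_like`, the float and int hyperparameters are kept too, so an optimizer initialised from that tree would try to track state for them.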
Right! So moving this over was deemed nontrivial as it depends on JAX's own pretty-printing, and I didn't really want to duplicate all of JAX pp + Equinox pp into...
Interesting! This is a fun one. So it's great that you've already been able to isolate this to a particular set of inputs and to a particular iteration of the...
Awesome! I'm glad we have an answer. I'd be very happy to take a PR making this tweak if/when you feel it's appropriate. (I can see there are some ongoing...
Ah, I think this is something JAX still needs better docs for. In response to the various points you raise: - `filter_vmap` definitely shouldn't have anything to do with sharding....
Yeah, I don't see a reason to bump this until we also find ourselves needing to depend on JAX 0.4.31+.
You've forgotten to call `.block_until_ready()`. Equinox actually calls this for you automatically (https://github.com/patrick-kidger/equinox/blob/15a800dd0ab1fc91b033c9305a5fe2f7bf2aecae/equinox/_jit.py#L248), but the others don't do this by default. Equinox does this so that runtime errors are...
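For comparison, here is a minimal timing sketch in plain JAX, where `jax.jit` does not block for you. The function `matmul` and the array size are arbitrary choices for illustration. The compiled function is warmed up once so that compilation isn't timed, and `.block_until_ready()` makes the timer wait for the computation itself rather than just its dispatch.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def matmul(x):
    return x @ x


x = jnp.ones((500, 500))

# Warm-up: trigger compilation outside the timed region.
matmul(x).block_until_ready()

start = time.perf_counter()
out = matmul(x)          # may return before the work finishes (asynchronous dispatch)
out.block_until_ready()  # wait until the result has actually been computed
elapsed = time.perf_counter() - start
print(f"matmul took {elapsed:.6f} s")
```

Without the `block_until_ready()` call inside the timed region, the measurement can mostly reflect dispatch overhead rather than the cost of the matmul itself.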
You need to specify the tolerances for the solver's root finder:

```python
import diffrax

solver = diffrax.ImplicitEuler(root_finder=diffrax.VeryChord(rtol=1e-8, atol=1e-8))
```

What's going on here is that by default we set the...
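As a rough illustration of why the root finder's tolerances matter, here is a toy implicit Euler step for the scalar ODE dy/dt = -y³, solved with a plain chord iteration (Newton's method with the Jacobian frozen at the initial guess). This is a simplified sketch, not Diffrax's `VeryChord` implementation, which uses its own convergence criterion; `implicit_euler_step`, `residual` and the ODE itself are made up for illustration. Loose tolerances stop the iteration early and leave a visibly larger residual in the implicit equation.

```python
def implicit_euler_step(f, dfdy, y0, dt, rtol, atol, max_steps=100):
    """One implicit Euler step: solve z = y0 + dt * f(z) for z.

    Uses the chord method: Newton iterations with the derivative of
    g(z) = z - y0 - dt * f(z) frozen at the initial guess z = y0.
    """
    z = y0
    jac = 1.0 - dt * dfdy(y0)  # g'(y0), reused at every iteration
    for _ in range(max_steps):
        delta = -(z - y0 - dt * f(z)) / jac
        z = z + delta
        # Stop once the update is small relative to the tolerances.
        if abs(delta) <= atol + rtol * abs(z):
            break
    return z


def residual(f, y0, dt, z):
    # How far z is from actually satisfying the implicit equation.
    return abs(z - y0 - dt * f(z))


def f(y):
    return -(y**3)


def dfdy(y):
    return -3 * y**2


y0, dt = 1.0, 0.5
z_loose = implicit_euler_step(f, dfdy, y0, dt, rtol=1e-3, atol=1e-3)
z_tight = implicit_euler_step(f, dfdy, y0, dt, rtol=1e-8, atol=1e-8)
r_loose = residual(f, y0, dt, z_loose)
r_tight = residual(f, y0, dt, z_tight)
print(r_loose, r_tight)
```

Each step of a full solve inherits whatever error the root finder leaves behind, which is why its tolerances need to be set tightly enough for the accuracy you want.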
Wrapping these values into floating-point arrays fixes this one:

```python
import jax.numpy as jnp

t0 = jnp.array(0.)
dt0 = jnp.array(0.05)
t1 = jnp.array(1.)
y0 = jnp.array(1.0)
```

You're aptly demonstrating that our step-by-step API needs...
I think you should be able to do this by using the existing Optimistix optimizers (no need to define a new `Abstract{Search,Descent}`!) and wrapping their calls in a `jax.vmap`. If...
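A minimal sketch of the batching pattern, assuming nothing beyond JAX itself: a hand-written gradient-descent loop (`minimise_gd`, hypothetical) stands in for a single unbatched Optimistix solve, and `jax.vmap` maps it over a batch of independent problems. With a real Optimistix solver, you would wrap the solver call in the same way.

```python
import jax
import jax.numpy as jnp


def loss(y, target):
    # A simple quadratic bowl centred on `target`.
    return jnp.sum((y - target) ** 2)


def minimise_gd(y0, target, learning_rate=0.1, num_steps=200):
    # Stand-in for one unbatched optimizer call: fixed-step gradient descent.
    grad_fn = jax.grad(loss)

    def step(y, _):
        return y - learning_rate * grad_fn(y, target), None

    y_final, _ = jax.lax.scan(step, y0, None, length=num_steps)
    return y_final


# A batch of 4 independent problems, each with its own start point and target.
y0s = jnp.zeros((4, 3))
targets = jnp.arange(12.0).reshape(4, 3)

# vmap maps over the leading axis of both positional arguments.
batched_minimise = jax.jit(jax.vmap(minimise_gd))
solutions = batched_minimise(y0s, targets)
print(solutions)
```

One design consequence worth knowing: iterative solvers usually run a `lax.while_loop`, and under `jax.vmap` such a loop keeps stepping until every batch element has stopped, so the batch costs roughly as much as its slowest member.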