Patrick Kidger

1451 comments by Patrick Kidger

At first glance, the problem is probably that you're doing `opt_state = optim.init(eqx.filter(model, eqx.is_array_like))` rather than `opt_state = optim.init(eqx.filter(model, eqx.is_array))`. Equinox generally wants you to treat non-arrays as static.

Right! So moving this over was deemed nontrivial as it depends on JAX's own pretty-printing, and I didn't really want to duplicate all of JAX pp + Equinox pp into...

Interesting! This is a fun one. So it's great that you've already been able to isolate this to a particular set of inputs and to a particular iteration of the...

Awesome! I'm glad we have an answer. I'd be very happy to take a PR making this tweak if/when you feel it's appropriate. (I can see there are some ongoing...

Ah, I think this is something JAX still needs better docs for. In response to the various points you raise:

- `filter_vmap` definitely shouldn't have anything to do with sharding....

Yeah, I don't see a reason to bump this until we also find ourselves needing to depend on JAX 0.4.31+.

You've forgotten to call `.block_until_ready()`. Equinox will actually call this for you automatically (https://github.com/patrick-kidger/equinox/blob/15a800dd0ab1fc91b033c9305a5fe2f7bf2aecae/equinox/_jit.py#L248), but the others don't do this by default. Equinox does this so that runtime errors are...
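JAX dispatches work asynchronously, so a timer stopped straight after the call mostly measures dispatch rather than computation. A minimal benchmarking sketch with plain `jax.jit` (which does not block for you); the function `f` and the array size are arbitrary choices for illustration:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    return jnp.sin(x) @ jnp.cos(x).T


x = jnp.ones((500, 500))
f(x).block_until_ready()  # warm-up: compile once, outside the timed region

start = time.perf_counter()
out = f(x)  # returns once the work is dispatched, possibly before it finishes
dispatch_time = time.perf_counter() - start
out.block_until_ready()  # wait for the computation to actually complete
total_time = time.perf_counter() - start
```

Only `total_time` is a meaningful runtime measurement; `dispatch_time` can be much smaller.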

You need to specify the tolerances for the root finder used by the solver:

```python
import diffrax
solver = diffrax.ImplicitEuler(root_finder=diffrax.VeryChord(rtol=1e-8, atol=1e-8))
```

What's going on here is that by default we set the...

Wrapping these values in floating-point arrays fixes this one:

```python
import jax.numpy as jnp
t0 = jnp.array(0.)
dt0 = jnp.array(0.05)
t1 = jnp.array(1.)
y0 = jnp.array(1.0)
```

You're aptly demonstrating that our step-by-step API needs...

I think you should be able to do this by using the existing Optimistix optimizers (no need to define a new `Abstract{Search,Descent}`!) and wrapping their calls in a `jax.vmap`. If...