Patrick Kidger
Haha, you're definitely getting into the weeds here. So as `step` is a traced value then I'm guessing what you're getting at is that we could wrap the `cond_fn` in...
So the behavior here is actually intentional, in that: 1. the safeguard is important to catch a common class of error: forgetting to use the output from `.set`. 2. flattening+unflattening...
I've been noodling on this idea, and I've realised a potential problem with it. Namely, that there is a lot of JAX code out there that basically relies on the...
Thanks for the question! So this is a somewhat subtle point, and arguably a minor (easy to work around) bug in JAX itself. First of all, here's a smaller repro...
At least here, the choice between a lambda wrapper and a `jax.tree_util.Partial` wrapper doesn't matter at all. If you're after a more general rule to avoid footguns, then: - For...
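For illustration, a sketch using a hypothetical `scale` function, showing that a lambda closure and a `jax.tree_util.Partial` compute the same thing. The practical difference is that `Partial` is itself a pytree, so it can be passed through `jax.jit` as an ordinary argument, whereas a bare lambda can't be (a function isn't a valid JAX argument type).

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def scale(x, factor):  # hypothetical example function
    return x * factor


factor = jnp.array(3.0)

# Option 1: a closure. Fine to jit directly, but can't itself be a jit argument.
f_lambda = lambda x: scale(x, factor)

# Option 2: a Partial. `scale` is stored as static pytree metadata and
# `factor` as a pytree leaf, so the whole wrapper can cross a jit boundary.
f_partial = Partial(scale, factor=factor)


@jax.jit
def apply(f, x):
    return f(x)


x = jnp.arange(4.0)
out_lambda = jax.jit(f_lambda)(x)
out_partial = apply(f_partial, x)  # identical results to out_lambda
```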
Hey there! So I think this is one where we'd probably follow JAX's lead. If they add a `jax.bmap`, we'll add an `eqx.filter_bmap` :) On the idea of a `filter_scan`...
Hey there! It's great to hear about how you're using Equinox. As for the caching you're seeing, this is expected behaviour: `equinox.filter_jit` caches based on array shape+dtype, and on the...
This might be https://github.com/jax-ml/jax/issues/30517
This seems reasonable to me, but it would be a slightly breaking change, in that it does mean that this field can no longer be replaced using `eqx.tree_at`. (Not the...
This is a feature that I'd be happy to take a PR on :) FWIW we currently test `pickle` and `cloudpickle` for performing serialisation: https://github.com/patrick-kidger/jaxtyping/blob/main/test/test_serialisation.py . If you can identify...