Patrick Kidger
> Should I consider downcasting the automatically generated jax scalar arrays to primitive floats?

If I understand your problem correctly, then this would be an acceptable solution. Another option is...
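A minimal sketch of the downcasting approach. It assumes the scalars live in an ordinary pytree (the `params` dict and the helper `to_python_floats` are hypothetical) and that their values are concrete. `float(...)` needs a concrete value, so this conversion has to happen outside `jit`/`grad`; calling it on a tracer raises a concretization error.

```python
import jax
import jax.numpy as jnp


def to_python_floats(tree):
    """Replace every 0-d floating-point JAX array in `tree` with a Python float."""

    def convert(leaf):
        if (
            isinstance(leaf, jax.Array)
            and leaf.ndim == 0
            and jnp.issubdtype(leaf.dtype, jnp.floating)
        ):
            # Requires a concrete value: do not call this under jit or grad.
            return float(leaf)
        return leaf  # non-scalar arrays and other leaves are left unchanged

    return jax.tree_util.tree_map(convert, tree)


# Hypothetical parameters: two scalar arrays and one vector.
params = {"lr": jnp.asarray(1e-3), "scale": jnp.asarray(2.0), "w": jnp.ones(3)}
converted = to_python_floats(params)
```

Python floats are hashable and are not traced, which is the usual reason to prefer them for values that should behave as constants.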
> Differentiating with respect to the init=True arguments is exactly what I want to do.

Ah got it! In that case what you have written is incorrect, and is totally...
Side note @lockwo -- I think this interaction is highlighting that in JAX, using `init=False` has maybe zero use cases, and is just a footgun. We could add a warning...
I think you'll find that the cached property just gets recomputed, as it won't (shouldn't) be preserved across flattening and unflattening; in particular when crossing JIT or grad barriers.
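A small plain-JAX sketch of this behaviour. It assumes a hypothetical `Model` class registered as a custom pytree, with `functools.cached_property` as the cache and a global `CALLS` counter to track recomputation. Only `weight` is a pytree child, so any instance rebuilt by `tree_unflatten` (as happens inside `jax.jit`) starts with an empty cache.

```python
import functools

import jax
import jax.numpy as jnp

CALLS = {"n": 0}  # counts how often the "expensive" value is computed


@jax.tree_util.register_pytree_node_class
class Model:
    def __init__(self, weight):
        self.weight = weight

    @functools.cached_property
    def expensive(self):
        CALLS["n"] += 1
        return jnp.sum(self.weight**2)

    def tree_flatten(self):
        # Only `weight` is a child; the cached value in __dict__ is dropped.
        return (self.weight,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)  # a fresh instance with an empty cache


m = Model(jnp.arange(3.0))
m.expensive  # computed and cached on `m`
m.expensive  # cache hit, no recomputation

leaves, treedef = jax.tree_util.tree_flatten(m)
m2 = jax.tree_util.tree_unflatten(treedef, leaves)
m2.expensive  # recomputed: the cache did not survive the round trip
after_roundtrip = CALLS["n"]

out = jax.jit(lambda model: model.expensive)(m)  # recomputed again while tracing
```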
If it's specifically around autodiff then I still hold out hope that JAX will eventually come to support forward-through-`custom_vjp`, and that we can do away with the huge amounts of...
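For reference, a minimal plain-JAX sketch of the current limitation (the function `f` is a hypothetical example). Reverse mode uses the custom backward rule, whereas forward mode through a `jax.custom_vjp` function raises a `TypeError`. This reflects JAX's behaviour at the time of writing and may change.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def f(x):
    return jnp.sin(x)


def f_fwd(x):
    # Primal output plus residuals saved for the backward pass.
    return jnp.sin(x), jnp.cos(x)


def f_bwd(cos_x, g):
    # Cotangent with respect to each primal input, as a tuple.
    return (g * cos_x,)


f.defvjp(f_fwd, f_bwd)

grad_val = jax.grad(f)(0.0)  # reverse mode: works, uses f_bwd

try:
    jax.jvp(f, (0.0,), (1.0,))  # forward mode through custom_vjp
    jvp_supported = True
except TypeError:
    jvp_supported = False
```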
This looks pretty good to me! It's probably worth tweaking things slightly to have a single JIT wrapping the whole thing (including your final `outer_grads = eqx.filter_grad(...`), but otherwise I...
The output of an `eqx.filter_custom_jvp`-wrapped function must still consist only of JAX types. In this case you're returning the output of `new_inner_model`, which still has its static components. (I think)
Hi, thanks for the report! Getting an `IndexError` suggests that it may actually be an error within JAX's own error-checking machinery. Either way, can you try minimising this down to...
Interesting! This Equinox-only MWE here is super useful. As another preliminary result on my side, I can avoid this error by dropping back to JAX 0.4.35, so it seems that...
This bit:

```python
class ScalarContainer(eqx.Module, strict=True):
    def __init__(self, scalars: dict[str, dict]):
        for name, data in scalars.items():
            object.__setattr__(self, name, ScalarParameter(data))
```

is definitely not okay! Equinox modules are frozen dataclasses, and...