Owen L
They are both failing the term check, but differently. I believe this is because `EulerHeun.term_structure` is a `MultiTerm`, so it goes into a different branch of the check, and `term_contr_kwargs` expects a tuple...
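Below is a deliberately simplified, hypothetical model of that branching. The `Term`, `MultiTerm` and `check_term_kwargs` names are stand-ins written for illustration only; they are not diffrax's classes or its real check. A multi-term structure dispatches to a branch that wants one kwargs entry per sub-term (a tuple), so a single dict fails there with a different error than in the single-term branch.

```python
class Term:
    """Hypothetical stand-in for a single term (e.g. a drift or a control term)."""


class MultiTerm:
    """Hypothetical stand-in for a container of several terms."""

    def __init__(self, *terms):
        self.terms = terms


def check_term_kwargs(structure, term_kwargs):
    """Toy check: validate kwargs against a (possibly nested) term structure."""
    if isinstance(structure, MultiTerm):
        # Multi-term branch: one kwargs entry per sub-term, so a tuple is required.
        if not isinstance(term_kwargs, tuple):
            raise TypeError("expected a tuple of kwargs, one per sub-term")
        if len(term_kwargs) != len(structure.terms):
            raise ValueError("need exactly one kwargs entry per sub-term")
        for sub_term, sub_kwargs in zip(structure.terms, term_kwargs):
            check_term_kwargs(sub_term, sub_kwargs)
    elif not isinstance(term_kwargs, dict):
        # Single-term branch: a plain dict of kwargs is expected.
        raise TypeError("expected a dict of kwargs")


multi = MultiTerm(Term(), Term())
check_term_kwargs(multi, ({}, {}))  # OK: one dict per sub-term
try:
    check_term_kwargs(multi, {})  # a single dict hits the multi-term branch
except TypeError as err:
    print("multi-term branch:", err)
```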
Equinox should (in most cases) be installable wherever JAX is installable, so if JAX works on the Jetson Xavier, Equinox probably will too. It seems JAX has aarch64 wheels now...
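Equinox itself is pure Python, so the part that matters is whether JAX supports the platform. Here is a quick sanity check to run on the board. It is a minimal sketch: `environment_report` is a hypothetical helper, and it only checks that the packages can be found, not that any accelerator backend works.

```python
import importlib.util
import platform


def environment_report():
    """Report the CPU architecture and whether jax/equinox are importable."""
    report = {"machine": platform.machine()}  # e.g. "aarch64" on ARM boards
    for name in ("jax", "equinox"):
        # find_spec returns None (rather than raising) for a missing package.
        report[name] = importlib.util.find_spec(name) is not None
    return report


if __name__ == "__main__":
    print(environment_report())
```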
If you want to use forward mode to save memory, there is `diffrax.ForwardMode()`, passed as `adjoint=diffrax.ForwardMode()` to `diffeqsolve`.
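The following plain-JAX sketch shows why forward mode can save memory. It does not use diffrax; `euler_solve` is a hypothetical stand-in for a solver loop. `jax.jvp` pushes the tangent along with the state step by step, so nothing is stored for a backward pass and memory does not grow with the number of steps.

```python
import jax
import jax.numpy as jnp


def euler_solve(k, y0=1.0, dt=0.01, n_steps=100):
    """Explicit Euler for dy/dt = -k*y; returns y at t = n_steps * dt."""
    def step(_, y):
        return y + dt * (-k * y)

    y_init = y0 * jnp.ones_like(k)  # match k's dtype so the loop carry type is stable
    return jax.lax.fori_loop(0, n_steps, step, y_init)


k = jnp.asarray(0.5, dtype=jnp.float32)
# Forward mode: the tangent dk = 1 travels through the loop next to y.
y_final, dy_dk = jax.jvp(euler_solve, (k,), (jnp.ones_like(k),))
```

Here the exact discrete answer is `y_final = (1 - k*dt)**n_steps`, so `dy_dk` should match `-n_steps * dt * (1 - k*dt)**(n_steps - 1)`.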
> I did not find an option that simply stores every intermediate solver state once and performs a single reverse sweep with zero recomputation (time O(n), memory O(n)); the closest...
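In plain JAX (outside diffrax), the behaviour the quote asks for is what reverse mode through `jax.lax.scan` does by default. The forward pass saves the residuals of every step, and a single reverse sweep consumes them, giving O(n) time and O(n) memory. Wrapping the step in `jax.checkpoint` instead saves only the step's inputs and recomputes its internals on the way back. This is a minimal sketch; `solve_scan` is a hypothetical stand-in for a solver loop.

```python
import jax
import jax.numpy as jnp


def solve_scan(k, y0=1.0, dt=0.01, n_steps=100, remat=False):
    """Explicit Euler for dy/dt = -k*y, written as a scan over steps."""
    def step(carry, _):
        y, rate = carry
        return (y + dt * (-rate * y), rate), None

    if remat:
        # Save only the step's inputs; recompute its internals on the backward pass.
        step = jax.checkpoint(step)
    y_init = y0 * jnp.ones_like(k)
    (y_final, _), _ = jax.lax.scan(step, (y_init, k), xs=None, length=n_steps)
    return y_final


k = jnp.asarray(0.5, dtype=jnp.float32)
# Default: per-step residuals are stored, then one reverse sweep uses them.
grad_stored = jax.grad(solve_scan)(k)
# Rematerialised variant: same gradient, different memory/compute trade-off.
grad_remat = jax.grad(lambda k: solve_scan(k, remat=True))(k)
```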
I'm not sure there's really an Equinox solution to this; it's more of a JAX thing. JAX wants the arrays you `vmap` over to have the same shape (more specifically, it...
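A common plain-JAX workaround is to pad every item to a common length, carry a boolean mask, and `vmap` over the padded batch. This is a sketch, not an Equinox feature; `pad_and_mask` and `masked_mean` are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def pad_and_mask(seqs):
    """Pad variable-length sequences with zeros and build a validity mask."""
    max_len = max(len(s) for s in seqs)
    padded = jnp.stack(
        [jnp.pad(jnp.asarray(s, dtype=jnp.float32), (0, max_len - len(s))) for s in seqs]
    )
    mask = jnp.stack([jnp.arange(max_len) < len(s) for s in seqs])
    return padded, mask


def masked_mean(x, m):
    """Mean over the valid (unpadded) entries only."""
    return jnp.sum(jnp.where(m, x, 0.0)) / jnp.sum(m)


seqs = [[1.0, 2.0, 3.0], [4.0], [5.0, 6.0]]
padded, mask = pad_and_mask(seqs)          # every row now has shape (3,)
means = jax.vmap(masked_mean)(padded, mask)  # means is [2.0, 4.0, 5.5]
```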
Two main things. First, this happens because `filter_jit` is actually transforming the function into a `Partial` Module, with the member variable being part of it. So it goes from ```...
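The next example shows the same idea in plain JAX, using `jax.tree_util.Partial` rather than Equinox's internals. It is a sketch; `scale` and `apply` are hypothetical. A bound argument becomes a leaf of the pytree, so under `jax.jit` it is traced like any other input instead of being baked in as a constant.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial, tree_leaves

n_traces = 0


def scale(w, x):
    return w * x


@jax.jit
def apply(fn, x):
    global n_traces
    n_traces += 1  # Python side effect: runs only when jit (re)traces
    return fn(x)


f1 = Partial(scale, jnp.array(2.0))
f2 = Partial(scale, jnp.array(3.0))
print(tree_leaves(f1))  # the bound array shows up as a pytree leaf
print(apply(f1, jnp.array(1.0)), apply(f2, jnp.array(1.0)))
print(n_traces)  # 1: f2 has the same structure, so its bound value is just a new input
```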
`Partial` is just a pytree wrapper, so it can be used at any point. If you encountered errors while training, please share an MRE (minimal reproducible example).
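This example binds the training data into a loss and passes it straight into a jitted training step. It is a minimal sketch with a hypothetical linear model, using JAX's `jax.tree_util.Partial` as a stand-in for Equinox's.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial


def loss(params, x, y):
    pred = params["w"] * x + params["b"]
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(bound_loss, params, lr):
    # bound_loss is a Partial, i.e. a pytree, so it can be a jit argument.
    grads = jax.grad(bound_loss)(params)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)


x = jnp.linspace(0.0, 1.0, 8)
y = 2.0 * x + 1.0
bound_loss = Partial(loss, x=x, y=y)  # data bound as keyword arguments
params = {"w": jnp.zeros(()), "b": jnp.zeros(())}
for _ in range(500):
    params = train_step(bound_loss, params, 0.5)
```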
Yes, because each layer updates the state, which is why the new state is returned.
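Here is the pattern sketched in plain JAX rather than with Equinox. In Equinox, stateful layers such as `eqx.nn.BatchNorm` are called as `layer(x, state)` and return `(output, state)`. In this sketch, `running_mean_layer` and `model` are hypothetical. Because arrays are immutable, an "update" can only be communicated by returning the new state, which the caller must thread into the next call.

```python
import jax.numpy as jnp


def running_mean_layer(x, mean, momentum=0.9):
    """Hypothetical stateful layer: subtracts a running mean of its inputs."""
    new_mean = momentum * mean + (1.0 - momentum) * jnp.mean(x)
    return x - new_mean, new_mean  # the updated state must be returned


def model(x, states):
    """Thread each layer's state through and collect the updated states."""
    new_states = []
    for mean in states:
        x, mean = running_mean_layer(x, mean)
        new_states.append(mean)
    return x, new_states


x = jnp.array([1.0, 2.0, 3.0])
states = [jnp.zeros(()), jnp.zeros(())]
y, states = model(x, states)  # reuse the returned `states` on the next call
```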
Inference mode basically just selects leaves with `lambda leaf: hasattr(leaf, "inference")`, so the model doesn't have to be filtered first.
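A toy re-creation of that selection in plain JAX follows. The `Dropout`, `Linear`, `Model` classes and `set_inference` are hypothetical; Equinox's own helper is `eqx.nn.inference_mode` in recent versions (formerly `eqx.tree_inference`). Every node with an `inference` attribute is treated as a leaf and replaced, and nothing else needs to be filtered out beforehand.

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Dropout:
    """Not registered as a pytree, so JAX treats it as a leaf."""
    rate: float
    inference: bool = False


@dataclasses.dataclass(frozen=True)
class Linear:
    weight: jax.Array


@dataclasses.dataclass(frozen=True)
class Model:
    linear: Linear
    dropout: Dropout


# Register the containers so tree_map can walk into them.
jax.tree_util.register_pytree_node(
    Linear, lambda l: ((l.weight,), None), lambda _, c: Linear(*c))
jax.tree_util.register_pytree_node(
    Model, lambda m: ((m.linear, m.dropout), None), lambda _, c: Model(*c))


def set_inference(tree, value):
    """Return a copy of `tree` with every `inference` flag set to `value`."""
    has_inference = lambda leaf: hasattr(leaf, "inference")
    return jax.tree_util.tree_map(
        lambda leaf: dataclasses.replace(leaf, inference=value) if has_inference(leaf) else leaf,
        tree,
        is_leaf=has_inference,
    )


model = Model(Linear(jnp.ones((2, 2))), Dropout(rate=0.1))
eval_model = set_inference(model, True)  # the original `model` is unchanged
```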
Does this only happen with batch norm? I would assume JAX wants the pytree leaves to be JAX arrays, so my go-to would be `filter`/`combine` (as opposed...
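Here is a plain-JAX sketch of that filter-and-combine pattern. Equinox provides `eqx.partition` and `eqx.combine` for this; the `partition` and `combine` below are simplified hypothetical versions. The idea is to split the pytree into its array leaves and everything else, pass only the arrays through the transformation, then merge the two halves back. Passing `params` directly to `jax.jit` would fail on the string leaf.

```python
import jax
import jax.numpy as jnp


def partition(tree):
    """Split into (arrays, static): each keeps its own leaves and None elsewhere."""
    is_array = lambda x: isinstance(x, jax.Array)
    arrays = jax.tree_util.tree_map(lambda x: x if is_array(x) else None, tree)
    static = jax.tree_util.tree_map(lambda x: None if is_array(x) else x, tree)
    return arrays, static


def combine(arrays, static):
    """Merge two partitioned trees, taking whichever side is not None."""
    return jax.tree_util.tree_map(
        lambda a, s: s if a is None else a, arrays, static, is_leaf=lambda x: x is None)


@jax.jit
def normalise(arrays):
    # Only array leaves reach jit; None entries are empty subtrees.
    return jax.tree_util.tree_map(lambda x: x / jnp.linalg.norm(x), arrays)


params = {"w": jnp.ones(3), "activation": "relu", "n_layers": 2}
arrays, static = partition(params)
out = combine(normalise(arrays), static)
```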