Patrick Kidger
Thanks for the report! It turns out that the upcoming #660 will solve this - we're just about to land that. You should be able to install Diffrax from #660...
I recommend differentiating with respect to raw JAX arrays (and internally wrapping these into a `LinearInterpolation` object). It should then be fairly clear what the gradients mean,...
Thanks for the report! This makes sense to me. I've tweaked things in #682.
So Equinox tries to avoid special-casing any module methods. In particular, someone may already have a method called `tree_unflatten` (etc.) for some purpose unrelated to...
So I'm really leaning against adding something like that to `eqx.Module`. Part of the design thesis of `eqx.Module`, as compared to `jax.tree_util`, is that custom flatten/unflatten functions are error-prone and...
I'm afraid Signatory is no longer supported. You can probably make it run, but you may have to do the work yourself :)
I think you should be able to do this by just wrapping a diffeqsolve + algebraic solution in a `lax.scan`. :)
Yup, exactly. I'm not sure I understand your question about `SaveAt`, I'm afraid. If you'd like to save the full solution at specific times `ts` and at `t1` then you...
Do you have a MWE? That aside, I'd also recommend raising this issue on the JAX issue tracker (with a reproducible example) if you're seeing severe performance drops.
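When putting a performance MWE together, a minimal timing harness in plain JAX may help; this is only a sketch, `computation` and `benchmark` are hypothetical names, and the workload is a placeholder. The two details that usually matter are excluding the first call (which includes compilation) and calling `block_until_ready`, since JAX dispatches work asynchronously.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def computation(x):
    # Placeholder workload; substitute the code whose performance regressed.
    return jnp.tanh(x @ x.T).sum()


def benchmark(fn, *args, repeats=10):
    # The first call includes tracing and compilation, so run it once untimed.
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # Block on each result so the timing covers the computation itself,
        # not just the asynchronous dispatch.
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats


x = jnp.ones((512, 512))
seconds_per_call = benchmark(computation, x)
print(f"{seconds_per_call * 1e3:.3f} ms per call")
```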
Oh, this is super weird. I'm able to reproduce this behaviour. Here's a smaller MWE:

```python
import os

os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import diffrax
import equinox as eqx
import jax

jax.config.update("jax_enable_x64",...
```