Owen L
If you just want to update one member variable of the module, you can just use `tree_at`: ```python import jax from jax import numpy as jnp import equinox...
Sure, that would work. It seems like `width` affects/determines other variables (such as `w`), but the code you ran would work.
I just answer a few issues; all credit goes to Patrick.
Looks like this is fixed on main: I see 0.0 as the output with jax and equinox both on main. I believe this was a similar error in the jacfwd to...
Augmenting the error message is definitely a good idea (related issues: https://github.com/patrick-kidger/diffrax/issues/461, https://github.com/patrick-kidger/diffrax/issues/446); the core issue currently is that the message isn't very informative about why the terms are failing...
A rough sketch of my idea: https://github.com/patrick-kidger/diffrax/pull/478
I recommend checking out JAX's docs on benchmarking https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code. The TL;DR for this example is that: 1. JAX compile times will be longer for the first iteration (and are generally excluded...
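Following the pattern in that FAQ entry, here is a minimal sketch. Call the jitted function once so tracing and compilation happen outside the timed region. Then time later calls with `block_until_ready()`, so JAX's asynchronous dispatch doesn't make runs look faster than they are. The function `f` and the array size are arbitrary placeholders.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def f(x):
    # Placeholder workload; substitute the function being benchmarked.
    return jnp.sin(x) @ jnp.cos(x).T


x = jnp.ones((500, 500))

# Warm-up call: triggers tracing + XLA compilation, excluded from timing.
f(x).block_until_ready()

n_runs = 10
start = time.perf_counter()
for _ in range(n_runs):
    # Wait for the computation to finish, not just for dispatch to return.
    f(x).block_until_ready()
avg_seconds = (time.perf_counter() - start) / n_runs
print(f"average per call: {avg_seconds * 1e3:.3f} ms")
```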
Yes, this behavior is expected. The Python floats are getting marked as static by the filtering that happens before jit. You can make them not static by making them jax...
> ...and it's actually really useful to me when these breaks do occur downstream, because then someone comes here and opens an issue/PR to let us know!

With respect to...
Oops, I think I cleaned out that W&B account recently. I can see about re-running the experiments if needed. For the 9x9 Go, there just so happens to be a...