Patrick Kidger

267 comments by Patrick Kidger

Ech, looks like this runs afoul of #176 as well. Which version of JAX are you using? For the purposes of debugging this I'll downgrade. (And #176 will be fixed...

Okay, I think I've tracked this down. It's because without the diffusion term, your computation is actually unbatched: your only batched input is `key`, but this is unused. So JAX...
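
As a minimal sketch of the situation described above (not the original code from the issue: `drift_only`, the array shapes, and the batch size are hypothetical), the snippet below uses `jax.vmap` to map over a batch of keys that the function never reads. Since nothing depends on the batched input, every row of the result is the same.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def drift_only(key, y0):
    # Hypothetical stand-in for a solve with no diffusion term:
    # `key` is accepted but never used.
    del key
    return y0 + 0.1 * jnp.sin(y0)


keys = jr.split(jr.PRNGKey(0), 4)  # the only batched input, batch size 4
y0 = jnp.arange(3.0)               # shared across the batch (not batched)

# Map over `keys` only; `y0` is passed through unbatched.
out = jax.vmap(drift_only, in_axes=(0, None))(keys, y0)

# For comparison: a single unbatched call.
single = drift_only(keys[0], y0)

print(out.shape)                          # one row per key
print(bool(jnp.allclose(out, single)))    # each row matches the single call
```

Making the function actually consume `key` (for example by sampling noise from it) is what would make the computation depend on the batch axis.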

Upgrading gives a performance boost -- this is because I've been working to improve the efficiency of Diffrax :) (And there's more stuff coming in just over the horizon: leave...

In terms of how I tracked this down: nope, no jaxprs. They're sadly not that helpful for debugging anything Diffrax-related. Differential equation solvers are large and complicated enough that they...

So yeah, there are some definite differences here due to whether we're evaluating just `drift`, just `diff`, or both. In particular these two networks are also of different sizes, and I...

You might find my other project [jaxtyping](https://github.com/google/jaxtyping) interesting. This is able to handle other array/tensor types -- at minimum it is tested with JAX+numpy+pytorch+tensorflow. This is...

So it's not a documented feature, but [Equinox](http://github.com/patrick-kidger/equinox) actually has a tree-math-like sublibrary built in, which can be used to do this kind of multi-axis stuff. To set the scene,...