cgiovanetti
I've tried with both CPU and GPU and had the same result. I can't access `sol_at_MT.result`; the code hangs before it finishes assigning values to that variable (setting `throw=False` doesn't...
Thanks for the tip about the debug statement! Apologies about the `float64` line missing--I pulled this snippet from a larger block of code and must have left that line out....
That's odd--I can solve this dummy system with `scipy.integrate.solve_ivp` in a few seconds. I use their built-in method `BDF`, but changing the `PIDController` tolerances to something smaller doesn't seem to...
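For anyone wanting to reproduce the SciPy side of this comparison, here is a minimal sketch of a `solve_ivp` call with `method="BDF"`. The system from this thread isn't shown here, so the Robertson kinetics problem is used purely as a stand-in stiff system; the function name `robertson`, the time span, and the tolerances are all illustrative assumptions, not the values from my actual code.

```python
from scipy.integrate import solve_ivp


def robertson(t, y):
    # Stand-in stiff system (Robertson chemical kinetics), NOT the system
    # discussed in this thread. The three components always sum to 1.
    y1, y2, y3 = y
    return [
        -0.04 * y1 + 1.0e4 * y2 * y3,
        0.04 * y1 - 1.0e4 * y2 * y3 - 3.0e7 * y2**2,
        3.0e7 * y2**2,
    ]


# Implicit BDF method with (assumed) tight tolerances; y0 = [1, 0, 0].
sol = solve_ivp(
    robertson,
    (0.0, 40.0),
    [1.0, 0.0, 0.0],
    method="BDF",
    rtol=1e-6,
    atol=1e-10,
)

# status == 0 means the integrator reached the end of the time span.
print(sol.status, sol.message)
print("final state:", sol.y[:, -1])
print("steps taken:", sol.t.size - 1)
```

Checking `sol.status` and the step count gives a baseline to compare against whatever the `diffrax` solve reports for the same problem and tolerances.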
I've found some other examples of systems that `diffrax` seems to have a hard time solving, in the hopes it might make the issue a little clearer. In particular, here's...
Sorry it's taken so long for me to get back to you; I just wanted to be sure this was also helping with the larger set of DEs I'm trying to...
No--we had to go through the other dependencies one by one, but still ended up with the same `scipy.linalg` error at the end of it.
Okay--is diffrax uninstalling/reinstalling JAX because of a version issue, then? I.e., is the conda-installed JAX too new/old? Or is there some other reason it might not see the conda...
Maybe the sharper question to ask is: if I do not provide the `--no-deps` flag when pip installing diffrax, what version(s) of JAX must I have for JAX not to...
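One way to look at this locally, sketched with only the standard library: `importlib.metadata` can report which version of a distribution is installed and which requirement strings another distribution declares. The helper names `declared_requirements` and `installed_version` are hypothetical, and this only shows what the installed metadata says; it doesn't reproduce pip's resolver, so treat it as a rough check rather than a definitive answer.

```python
from importlib import metadata


def declared_requirements(dist_name, needle):
    """Return the requirement strings of `dist_name` that mention `needle`.

    Returns an empty list if `dist_name` is not installed or declares
    no requirements.
    """
    try:
        reqs = metadata.requires(dist_name)
    except metadata.PackageNotFoundError:
        return []
    return [r for r in (reqs or []) if needle.lower() in r.lower()]


def installed_version(dist_name):
    """Return the installed version string of `dist_name`, or None."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Compare what diffrax asks for with what is actually in the environment.
print("diffrax requires:", declared_requirements("diffrax", "jax"))
print("jax installed:   ", installed_version("jax"))
print("jaxlib installed:", installed_version("jaxlib"))
```

If the installed `jax` version falls outside the range printed on the first line, that would be one plausible reason for pip to replace it when `--no-deps` is omitted.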
Thanks for the fast reply! Sticking to old jax distributions will work in some cases but in general isn't the ideal workaround for me--I'm working on a mixed CPU/GPU project...
Interesting--this cuts my runtime way down for jax==0.6.2, and has virtually no effect for jax==0.8.1. Thanks for the insight! I'll see if I can get an MWE with a double...