diffrax
Adding support for complex dtypes
Hi,
I finally have the time to actually finish this PR. I've been trying to start from scratch instead of going over all of the changes. Currently the solvers are able to return a solution, but it's the wrong solution. I'm not sure how to debug this; any suggestions?
Some suggestions:
- Go through the guts of the integration code, adding `jax.experimental.host_callback.id_print` statements to see what value each array takes and where things go wrong.
- Try using e.g. `Tsit5().step` directly instead of the `diffeqsolve` interface. Check whether that works as you expect.
- Try using only constant step sizes, not adaptive step sizing, if you aren't already. Just try to get the simplest possible thing working first.
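To make the constant-step suggestion concrete, here is a minimal sketch of a fixed-step reference solver in plain JAX, independent of diffrax. The test problem dy/dt = lam * y with a complex lam is an assumption chosen for illustration; it has the closed-form solution y0 * exp(lam * t), so the output (including its imaginary part) can be checked directly. `rk4_constant_step` and `LAM` are hypothetical names, not diffrax API.

```python
import jax
import jax.numpy as jnp

# Hypothetical test problem: dy/dt = LAM * y with a complex LAM,
# whose exact solution is y(t) = y0 * exp(LAM * t).
LAM = jnp.asarray(-0.5 + 2.0j, dtype=jnp.complex64)


def vector_field(t, y):
    return LAM * y


def rk4_constant_step(f, t0, t1, y0, dt):
    """Classical fourth-order Runge-Kutta with a fixed step size (no adaptivity)."""
    num_steps = int(round((t1 - t0) / dt))

    def body(n, y):
        t = t0 + n * dt
        k1 = f(t, y)
        k2 = f(t + dt / 2, y + dt / 2 * k1)
        k3 = f(t + dt / 2, y + dt / 2 * k2)
        k4 = f(t + dt, y + dt * k3)
        return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)

    return jax.lax.fori_loop(0, num_steps, body, y0)


y0 = jnp.asarray([1.0 + 0.0j, 0.5 - 1.0j], dtype=jnp.complex64)
y1 = rk4_constant_step(vector_field, 0.0, 1.0, y0, 0.01)
exact = y0 * jnp.exp(LAM * 1.0)

print(y1.dtype)  # complex64
print(jnp.allclose(y1, exact, rtol=1e-4, atol=1e-4))  # True
```

With a fixed step and a closed-form reference, step-size control is taken out of the picture, so a remaining mismatch in the library path points at the stepping code itself.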
It seems that `jax.experimental.host_callback.id_print` is doing nothing when I run in debug mode. Should I use it directly in a print statement?
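`jax.experimental.host_callback` has since been deprecated in JAX. On recent versions, `jax.debug.print` is the usual way to see runtime values from inside jitted code. A minimal sketch, where `euler_step` and its vector field dy/dt = 1j * y are hypothetical:

```python
import jax
import jax.numpy as jnp


@jax.jit
def euler_step(y, dt):
    # Hypothetical vector field dy/dt = 1j * y, used only to illustrate printing.
    dy = 1j * y
    # A plain Python print here would run once at trace time and show tracers;
    # jax.debug.print runs each time the compiled function executes and shows values.
    jax.debug.print("y = {y}, dy = {dy}", y=y, dy=dy)
    return y + dt * dy


y0 = jnp.asarray([1.0 + 0.0j, 0.0 + 1.0j], dtype=jnp.complex64)
y1 = euler_step(y0, 0.1)  # expected values: [1 + 0.1j, -0.1 + 1j]
```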
The default tolerances for `allclose` are `rtol=1e-05` and `atol=1e-08`. With step size `dt0=0.1` I think this is too strict a precision to ask for: in `test_basic`, a similar test uses `dt0=0.01`, and `rtol=1e-02` and `atol=1e-02` are chosen for the `allclose`.

Your approach fixes the issue for the Euler solver; however, there is another place in the Runge-Kutta solver that needs to be assigned the correct dtype. I'll check and test whether that is all, and will probably submit a PR if I can find and tweak all the lines that cause issues with complex numbers.
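A small sketch of both points, using a hand-written explicit Runge-Kutta step rather than diffrax's own solver code (`explicit_rk_step`, `solve_constant_step` and `EULER` are hypothetical names). The stage buffer is allocated with `y.dtype`, illustrating the kind of place where a dtype has to be set explicitly for complex states. With forward Euler on dy/dt = 1j * y over [0, 1], the error is roughly 0.05 at dt0=0.1 and roughly 0.005 at dt0=0.01, so the default `allclose` tolerances fail at either step size, while `rtol=atol=1e-2` passes at dt0=0.01.

```python
import jax
import jax.numpy as jnp

# Hypothetical Butcher tableau (a, b, c) for forward Euler.
EULER = ([[0.0]], [1.0], [0.0])


def explicit_rk_step(f, t, y, dt, a, b, c):
    """One step of an explicit Runge-Kutta method given its Butcher tableau."""
    num_stages = len(b)
    # Allocate the stage buffer with y's dtype: a default float buffer
    # cannot hold complex stage values.
    k = jnp.zeros((num_stages,) + y.shape, dtype=y.dtype)
    for i in range(num_stages):
        # Explicit method: stage i only uses stages 0..i-1.
        increment = sum((a[i][j] * k[j] for j in range(i)), jnp.zeros_like(y))
        k = k.at[i].set(f(t + c[i] * dt, y + dt * increment))
    return y + dt * sum((b[i] * k[i] for i in range(num_stages)), jnp.zeros_like(y))


def solve_constant_step(f, t0, t1, y0, dt, tableau):
    """Integrate from t0 to t1 with a fixed step size dt (no adaptivity)."""
    num_steps = int(round((t1 - t0) / dt))

    def body(n, y):
        return explicit_rk_step(f, t0 + n * dt, y, dt, *tableau)

    return jax.lax.fori_loop(0, num_steps, body, y0)


def vector_field(t, y):
    # dy/dt = 1j * y, exact solution y(t) = y0 * exp(1j * t).
    return 1j * y


y0 = jnp.asarray([1.0 + 0.0j, 0.0 + 1.0j], dtype=jnp.complex64)
exact = y0 * jnp.exp(1j * 1.0)

y_coarse = solve_constant_step(vector_field, 0.0, 1.0, y0, 0.1, EULER)
y_fine = solve_constant_step(vector_field, 0.0, 1.0, y0, 0.01, EULER)

print(jnp.abs(y_coarse - exact).max())  # roughly 0.05
print(jnp.abs(y_fine - exact).max())  # roughly 0.005
print(jnp.allclose(y_fine, exact))  # False (default rtol=1e-05, atol=1e-08)
print(jnp.allclose(y_fine, exact, rtol=1e-2, atol=1e-2))  # True
```

Forward Euler is first order, so shrinking the step by 10x shrinks the error by roughly 10x; the test tolerance has to be chosen to match the method's accuracy at the chosen `dt0`.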
Closing in favour of #330.