diffrax
Solving with complex initialization
When trying to run a basic example:
```python
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.], dtype=complex)
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
```
I receive this warning: `ComplexWarning: Casting complex values to real discards the imaginary part`, which goes away if I change the dtype of `y0` to `float`.
Moreover, when running something else with a complex `y0`, I got this error: `TypeError: nextafter does not accept dtype complex64 at position 0. Accepted dtypes at position 0 are subtypes of floating.`
Does diffrax not support complex `y0`s? If so, is there any plan to enable this? Although we can convert the complex elements to real arrays with an extra dimension, it would make the process a lot smoother if we could just use complex elements.
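As a minimal sketch of that conversion to real arrays (the helper names `complex_to_real`, `real_to_complex` and `realify_vector_field` are hypothetical, not part of Diffrax), the real and imaginary parts can be stacked along a new trailing axis, and the vector field wrapped so that it only ever receives and returns real arrays:

```python
import jax.numpy as jnp


def complex_to_real(z):
    # Hypothetical helper: stack real and imaginary parts along a new last axis,
    # so an array of shape (...) becomes a real array of shape (..., 2).
    return jnp.stack([jnp.real(z), jnp.imag(z)], axis=-1)


def real_to_complex(r):
    # Hypothetical helper: inverse of complex_to_real; also works when there
    # are extra leading (batch or time) axes.
    return r[..., 0] + 1j * r[..., 1]


def realify_vector_field(f):
    # Hypothetical helper: wrap a complex vector field f(t, y, args) so that it
    # takes and returns real arrays in the layout produced by complex_to_real.
    def real_f(t, y_real, args):
        dy = f(t, real_to_complex(y_real), args)
        return complex_to_real(dy)

    return real_f


# Usage with a complex-valued vector field, e.g. dy/dt = -i y.
def f(t, y, args):
    return -1j * y


y0 = jnp.array([2.0 + 1.0j, 3.0])
y0_real = complex_to_real(y0)  # shape (2, 2), real dtype
dy_real = realify_vector_field(f)(0.0, y0_real, None)
dy = real_to_complex(dy_real)  # matches f(0.0, y0, None)
```

Putting the real/imaginary axis last means `real_to_complex` should also apply directly to solver output whose leading axis indexes the saved times. With Diffrax one would presumably pass `ODETerm(realify_vector_field(f))` and `y0=complex_to_real(y0)` to `diffeqsolve`, then apply `real_to_complex` to `solution.ys`.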
Diffrax doesn't support complex numbers at the moment. This was a conscious choice on my part, so that other aspects of Diffrax could be prioritised.
My expectation is that there are probably just a few lines of code that currently assume reals and need tweaking. I'd welcome a PR on this.
From what I've looked at, it seems that this line (https://github.com/patrick-kidger/diffrax/blob/cec091c5e4cc4311f64ae3aa09a371db5fe766ee/diffrax/integrate.py#L771):

```python
ys = jax.tree_map(lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), y0)
```

is what's causing complex numbers to be cast to floats: `jnp.full` is called without a `dtype`, so the output buffer is created with a real dtype.
What do you think, @patrick-kidger? Are there other places that need changing? If it's only a matter of changing the dtype here to complex, I'll open a PR.
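The change being proposed here can be illustrated in isolation. The sketch below uses plain JAX with a made-up `out_size`, and is not Diffrax's actual code; it shows that the quoted `jnp.full` call produces a real buffer, and that passing the dtype of each leaf explicitly (one possible tweak) keeps complex values intact:

```python
import jax
import jax.numpy as jnp

out_size = 4  # hypothetical number of saved outputs
y0 = jnp.array([2.0 + 1.0j, 3.0 - 0.5j])

# As in the quoted line: no dtype is given, so jnp.full infers a real
# floating dtype from jnp.inf and the buffer cannot hold complex values.
ys_real = jax.tree_util.tree_map(
    lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf), y0
)

# Possible tweak (an assumption, not a merged fix): take the dtype from each leaf.
ys_complex = jax.tree_util.tree_map(
    lambda y: jnp.full((out_size,) + jnp.shape(y), jnp.inf, dtype=jnp.result_type(y)),
    y0,
)

# Writing the initial condition into the first row now keeps the imaginary part.
ys_complex = ys_complex.at[0].set(y0)
```

This only addresses the output buffer; the `nextafter` error reported earlier comes from a different code path.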
I haven't looked into this, so I don't know how many places might need changing. If you're happy to dig into it, I'd still be happy to accept a PR (including tests, to make sure this doesn't regress in the future). My expectation is that it's probably just a few lines that need tweaking.
What's the status of supporting complex dtypes in diffrax? I see that PR #112 was closed before merging.
BTW (in case it helps anyone looking at this thread), as a temporary workaround, I used the following isomorphism [1] between complex-valued and real matrices/vectors:

That said, this seems suboptimal compared to supporting complex dtypes directly.
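For reference, a commonly used form of this isomorphism (assumed here; the function names below are hypothetical) maps a complex matrix A + iB to the real block matrix [[A, -B], [B, A]] and a complex vector x + iy to the stacked real vector [x; y]. Since (A + iB)(x + iy) = (Ax - By) + i(Bx + Ay), matrix-vector products are preserved. A minimal JAX sketch:

```python
import jax.numpy as jnp


def realify_matrix(m):
    # Hypothetical helper: complex (n, n) matrix A + iB -> real (2n, 2n) matrix [[A, -B], [B, A]].
    a, b = jnp.real(m), jnp.imag(m)
    return jnp.block([[a, -b], [b, a]])


def realify_vector(v):
    # Hypothetical helper: complex length-n vector x + iy -> real length-2n vector [x; y].
    return jnp.concatenate([jnp.real(v), jnp.imag(v)])


def complexify_vector(r):
    # Hypothetical helper: inverse of realify_vector.
    n = r.shape[0] // 2
    return r[:n] + 1j * r[n:]


# Example: a Schroedinger-type right-hand side dpsi/dt = -i H psi becomes a real
# linear ODE with matrix realify_matrix(-1j * H) acting on realify_vector(psi).
H = jnp.array([[1.0, 0.5 - 0.5j], [0.5 + 0.5j, -1.0]])
psi = jnp.array([1.0 + 0.0j, 0.0 + 1.0j])
rhs_real = realify_matrix(-1j * H) @ realify_vector(psi)
rhs = complexify_vector(rhs_real)  # matches -1j * (H @ psi)
```

The real block matrix holds 4n² real numbers, twice the 2n² needed to store the complex matrix, which is one concrete cost of this workaround.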
[1] N. Leung, M. Abdelhafez, J. Koch, and D. Schuster, Speedup for quantum optimal control from automatic differentiation based on graphics processing units, Phys. Rev. A 95, 042318 (2017).
I was not able to continue with this PR (see the comment above).
Right now Diffrax doesn't support complex dtypes. I'd quite like for them to work, but realistically it's low down my own priority list.
I would be very happy to accept PRs that make the appropriate tweaks.