Complex input in diffeqsolve with PIDController
Hi, I encountered an issue while using the PIDController. When using complex input in `diffeqsolve`, I cannot change the coefficients of the PIDController. The default coefficients work fine, but when I change them I get an error, which leads me to believe that the time variable is converted to a complex type somewhere. Here is an MWE to reproduce the error:
```python
# %% ==========================================================================
# Imports
# =============================================================================
import jax.numpy as jnp
import diffrax as dx
from jaxtyping import Scalar

# %% ==========================================================================
# Minimal example: solving a complex ODE with diffrax
# ODE to solve is dy/dt = i*y, with y a complex number
# Setting up the ODE
# =============================================================================
def vector_field(t: Scalar, y: Scalar, *args):
    dotY = 1j * y
    return dotY

tsave = jnp.linspace(0.0, 10.0, 1000)
sim_time = tsave[-1]
solver = dx.Tsit5()
saveat = dx.SaveAt(ts=tsave)
y0 = 1.0j

# %% ==========================================================================
# Solve using the default PIDController coefficients (this works)
# =============================================================================
stepsize_controller = dx.PIDController(rtol=1e-6, atol=1e-6, pcoeff=0, icoeff=1, dcoeff=0.0)
term = dx.ODETerm(vector_field=vector_field)
res_dx = dx.diffeqsolve(term, solver, t0=0.0, t1=sim_time, dt0=0.01, y0=y0, saveat=saveat,
                        stepsize_controller=stepsize_controller)

# %% ==========================================================================
# Solve using a non-zero P coefficient (this raises the error below)
# =============================================================================
stepsize_controller = dx.PIDController(rtol=1e-6, atol=1e-6, pcoeff=0.3, icoeff=0.3, dcoeff=0.0)
term = dx.ODETerm(vector_field=vector_field)
res_dx = dx.diffeqsolve(term, solver, t0=0.0, t1=sim_time, dt0=0.01, y0=y0, saveat=saveat,
                        stepsize_controller=stepsize_controller)
```
## Changing the PIDController coefficients raises the following error
```
ValueError: `body_fun` must have the same input and output structure. Difference is:
State(
  y=c128[],
  tprev=f64[],
- tnext=f64[],
+ tnext=c128[],
  made_jump=bool[],
  solver_state=(bool[], c128[]),
- controller_state=(bool[], bool[], f64[], c128[], c128[]),
+ controller_state=(bool[], bool[], c128[], c128[], c128[]),
  result=EnumerationItem(
    _value=i32[],
    _enumeration=<class 'diffrax._solution.RESULTS'>
  ),
  num_steps=i64[],
  num_accepted_steps=i64[],
  num_rejected_steps=i64[],
  save_state=SaveState(
    saveat_ts_index=i64[],
    ts=_Buffer(
      _array=f64[1000],
      _pred=bool[],
      _tag=<object object at 0x2e4dc0850>,
      _makes_false_steps=False
    ),
    ys=_Buffer(
      _array=c128[1000],
      _pred=bool[],
      _tag=<object object at 0x2e4dc0850>,
      _makes_false_steps=False
    ),
    save_index=i64[]
  ),
  dense_ts=None,
  dense_infos=None,
  dense_save_index=None
)
```
Right! Complex numbers are only partially supported right now. The main blocker has been that we were waiting on an XLA bug fix, although I'm happy to say that it has recently been fixed.
Regardless, Diffrax's support for complex numbers is currently quite limited. You should be able to use them inside your vector field, as long as you split them into real and imaginary parts wherever you interface with Diffrax (for example `y0` and the value returned by the vector field).
@Ricky5389 can you try the latest dev version? It should work there.
Yes, it works. Thank you!