Complex input in diffeqsolve with PIDController

Open · Ricky5389 opened this issue 1 year ago · 3 comments

Hi, I encountered an issue while using the PIDController. When using complex input in diffeqsolve, I cannot change the coefficients of the PIDController. The default coefficients work fine, but when I change them I get an error, which leads me to believe that the time variable is converted to a complex type somewhere. Here is an MWE to reproduce the error:

# %% ==========================================================================
# Imports
# =============================================================================
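import jax
import jax.numpy as jnp
import diffrax as dx
from jaxtyping import Scalar

# The traceback below shows f64/c128 dtypes, so 64-bit mode was enabled
jax.config.update("jax_enable_x64", True)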

# %% ==========================================================================
# Smallest working example, solving a complex ODE with diffrax
# ODE to solve is dy/dt = iy with y a complex number
# Setting up the ODE
# =============================================================================
def vector_field(t: Scalar, y: Scalar, *args):
    dotY = 1j * y
    return dotY

tsave = jnp.linspace(0.0, 10.0, 1000)
sim_time = tsave[-1]

solver = dx.Tsit5()
saveat = dx.SaveAt(ts=tsave)
y0 = 1.0j
# %% ==========================================================================
# Solve using the default PIDController coefficients
# =============================================================================
stepsize_controller = dx.PIDController(rtol=1e-6, atol=1e-6, pcoeff=0, icoeff=1, dcoeff=0.0)

term = dx.ODETerm(vector_field=vector_field)
res_dx = dx.diffeqsolve(term, solver, t0=0.0, t1=sim_time, dt0=0.01, y0=y0, saveat=saveat,
                        stepsize_controller=stepsize_controller)

# %% ==========================================================================
# Solve using a non-zero P coefficient (this raises the error below)
# =============================================================================
stepsize_controller = dx.PIDController(rtol=1e-6, atol=1e-6, pcoeff=0.3, icoeff=0.3, dcoeff=0.0)

term = dx.ODETerm(vector_field=vector_field)
res_dx = dx.diffeqsolve(term, solver, t0=0.0, t1=sim_time, dt0=0.01, y0=y0, saveat=saveat,
                        stepsize_controller=stepsize_controller)


## Changing the PIDController coefficients raises the following error:
ValueError: `body_fun` must have the same input and output structure. Difference is:
  State(
    y=c128[],
    tprev=f64[],
-   tnext=f64[],
+   tnext=c128[],
    made_jump=bool[],
    solver_state=(bool[], c128[]),
-   controller_state=(bool[], bool[], f64[], c128[], c128[]),
+   controller_state=(bool[], bool[], c128[], c128[], c128[]),
    result=EnumerationItem(
      _value=i32[],
      _enumeration=<class 'diffrax._solution.RESULTS'>
    ),
    num_steps=i64[],
    num_accepted_steps=i64[],
    num_rejected_steps=i64[],
    save_state=SaveState(
      saveat_ts_index=i64[],
      ts=_Buffer(
        _array=f64[1000],
        _pred=bool[],
        _tag=<object object at 0x2e4dc0850>,
        _makes_false_steps=False
      ),
      ys=_Buffer(
        _array=c128[1000],
        _pred=bool[],
        _tag=<object object at 0x2e4dc0850>,
        _makes_false_steps=False
      ),
      save_index=i64[]
    ),
    dense_ts=None,
    dense_infos=None,
    dense_save_index=None
  )
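The diff above shows `tnext` and one entry of `controller_state` changing from `f64` to `c128` inside the loop, which fits the guess that the time variable becomes complex somewhere. One generic way this can happen in JAX is type promotion: if a step-size factor is built from an error estimate that is still complex (instead of from a real-valued norm of it), every time value derived from that factor becomes complex too. The sketch below only illustrates that promotion rule under this assumption; the variable names are hypothetical and this is not Diffrax's actual controller code.

```python
import jax

jax.config.update("jax_enable_x64", True)  # match the f64/c128 dtypes in the traceback

import jax.numpy as jnp

t = jnp.asarray(0.5)           # a real time value (float64 under x64)
dt = jnp.asarray(0.01)         # a real step size
err = jnp.asarray(1e-7 + 0j)   # hypothetical error estimate computed from a complex state

# A step-size factor built directly from the complex error stays complex...
factor = (1e-6 / err) ** 0.3
# ...and multiplying it into dt promotes the next time to complex128.
tnext = t + dt * factor
print(tnext.dtype)  # complex128

# Reducing the error to a real magnitude first keeps the time real.
factor_real = (1e-6 / jnp.abs(err)) ** 0.3
tnext_real = t + dt * factor_real
print(tnext_real.dtype)  # float64
```

The design point is simply that any quantity feeding into `t` or `dt` has to be real-valued before it mixes with them; JAX will not warn when a float is silently promoted to complex.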

Ricky5389, Mar 15 '24 19:03

Right! Complex numbers are only partially supported right now. The main blocker has been an XLA bug, although I'm happy to say that it has recently been fixed.

Regardless, for now Diffrax has fairly weak support for complex numbers. You should be able to use them inside your vector field, as long as you split them into real and imaginary parts wherever values are passed to or from Diffrax.

patrick-kidger, Mar 15 '24 22:03

@Ricky5389, can you try the latest dev version? It should work there.

Randl, Apr 22 '24 19:04

Yes, it works. Thank you!

Ricky5389, May 17 '24 20:05