Coupled SDE System Implementation

Open aspannaus opened this issue 1 year ago • 7 comments

Hi all,

Thanks for the great library. I'm having an issue implementing a coupled system of SDEs: I get the error ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`). The system is:

\begin{aligned}
	\mathrm{d} S(t) &= -\beta(t) S(t) \frac{I(t)}{N} \, \mathrm{d}t, \\
	\mathrm{d} I(t) &= \left(\beta(t) S(t) \frac{I(t)}{N} - \gamma(t) I(t)\right) \mathrm{d}t, \\
	\mathrm{d} R(t) &= \gamma(t) I(t) \, \mathrm{d}t, \\
	\mathrm{d} \log\beta(t) &= w_3 \, \mathrm{d}B_w(t), \\
	\mathrm{d} \log\gamma(t) &= u_3 \, \mathrm{d}B_u(t)
\end{aligned}

The code is


import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

import diffrax

def sde_drift(t, y, args):
    N, _ = args
    beta = jnp.exp(y[3])
    gamma = jnp.exp(y[4])
    dS = -(beta * y[0] * y[1]) / N
    dI = (beta * y[0] * y[1]) / N - y[1] * gamma
    dR = y[1] * gamma
    # only diffusion, no drift
    dbeta = 0.0  # jnp.array([0.0])
    dgamma = 0.0  # jnp.array([0.0])
    dy = jnp.array([dS, dI, dR, dbeta, dgamma])

    return dy

def sde_diffusion(t, y, args):
    _, sigma_1 = args
    y1, y2, y3, y4, y5 = y
    diagonal = jnp.array([0.0, 0.0, 0.0, sigma_1 * y4, sigma_1 * y5])
    return diagonal 


def sde():

    t0 = 0
    t1 = 100
    dt0 = 0.1
    y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
    args = (4000.0, 0.2)

    bm = diffrax.VirtualBrownianTree(t0, t1, tol=1e-2, shape=(5,), key=jr.PRNGKey(42))
    terms = diffrax.MultiTerm(diffrax.ODETerm(sde_drift), diffrax.ControlTerm(sde_diffusion, bm))
    solver = diffrax.SEA()
    saveat = diffrax.SaveAt(dense=True)

    print(type(terms))

    sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0=dt0, y0=y0, args=args, saveat=saveat)
    print(sol)

Printing the type of terms yields diffrax._term.MultiTerm, so I'm not sure where to look next. What would you suggest checking?

Thanks in advance.

aspannaus avatar Jul 10 '24 18:07 aspannaus

The classic (https://github.com/patrick-kidger/diffrax/issues/446#issuecomment-2187405940) strikes once again 😉

It seems like there are a few errors here. First, you return a diagonal, but ControlTerm expects a full matrix by default, so you need to wrap the diagonal (with lineax's DiagonalLinearOperator). Second, SEA requires space-time Lévy area, i.e. a Brownian motion created with levy_area=SpaceTimeLevyArea (this should go in the solver docs imo). Finally, SEA requires additive noise (i.e. the diffusion g is not a function of the state), so you can't use this solver with that noise term.
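As a rough illustration of the first point, in plain Python rather than diffrax (the helper names and the numbers are made up for this example): a diagonal diffusion stored as a length-5 vector acts on the Brownian increment elementwise, and that is the same as multiplying by the full 5x5 matrix with that vector on its diagonal. As I read the point above, wrapping the vector in a diagonal operator is what signals this interpretation.

```python
def to_dense(diagonal):
    """Build the full square matrix that has `diagonal` on its main diagonal."""
    n = len(diagonal)
    return [[diagonal[i] if i == j else 0.0 for j in range(n)] for i in range(n)]


def matvec(matrix, vector):
    """Dense matrix-vector product: what a full-matrix diffusion would compute."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, vector)) for row in matrix]


def diag_matvec(diagonal, vector):
    """Diagonal operator applied to a vector: an elementwise product."""
    return [d_i * v_i for d_i, v_i in zip(diagonal, vector)]


# Hypothetical values: noise only enters the last two components,
# mirroring the shape of the diffusion in the question.
diagonal = [0.0, 0.0, 0.0, 0.5, -0.25]
dw = [0.1, -0.2, 0.3, 0.4, 0.8]  # stand-in for one Brownian increment

dense_result = matvec(to_dense(diagonal), dw)
diag_result = diag_matvec(diagonal, dw)

print(dense_result)
print(diag_result)
```

Both calls give the same five numbers; the diagonal form just avoids building and multiplying by the mostly-zero matrix.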

Applying all three fixes (with GeneralShARK in place of SEA) gives something that works:


import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt

import diffrax
import lineax as lx

def sde_drift(t, y, args):
    N, _ = args
    beta = jnp.exp(y[3])
    gamma = jnp.exp(y[4])
    dS = -(beta * y[0] * y[1]) / N
    dI = (beta * y[0] * y[1]) / N - y[1] * gamma
    dR = y[1] * gamma
    # only diffusion, no drift
    dbeta = 0.0  # jnp.array([0.0])
    dgamma = 0.0  # jnp.array([0.0])
    dy = jnp.array([dS, dI, dR, dbeta, dgamma])

    return dy

def sde_diffusion(t, y, args):
    _, sigma_1 = args
    y1, y2, y3, y4, y5 = y
    diagonal = jnp.array([0.0, 0.0, 0.0, sigma_1 * y4, sigma_1 * y5])
    return lx.DiagonalLinearOperator(diagonal) 


def sde():

    t0 = 0
    t1 = 100
    dt0 = 0.1
    y0 = jnp.array([3990.0, 10.0, 0.01, jnp.log(0.25), jnp.log(0.05)])
    args = (4000.0, 0.2)

    bm = diffrax.VirtualBrownianTree(t0, t1, tol=1e-2, shape=(5,), key=jr.PRNGKey(42), levy_area=diffrax.SpaceTimeLevyArea)
    terms = diffrax.MultiTerm(diffrax.ODETerm(sde_drift), diffrax.ControlTerm(sde_diffusion, bm))
    solver = diffrax.GeneralShARK()
    saveat = diffrax.SaveAt(dense=True)

    print(type(terms))

    sol = diffrax.diffeqsolve(terms, solver, t0, t1, dt0=dt0, y0=y0, args=args, saveat=saveat)
    print(sol)

sde()

lockwo avatar Jul 11 '24 02:07 lockwo

Thanks for the reply; I must have missed some of the points about the solver you make in the docs.

Trying the code you suggested, I get the error ValueError: Custom node type mismatch: expected type: <class 'lineax._operator.DiagonalLinearOperator'>, value: Traced<ShapedArray(float32[5])>with<DynamicJaxprTrace(level=2/0)>. I had tried this previously without success, but perhaps it is correct and there's something behind the scenes happening?

For completeness, here are the library versions I'm using:

  • jax: 0.4.30
  • diffrax: 0.5.1
  • lineax: 0.0.5
  • equinox: 0.11.4

Thanks again for the assistance.

aspannaus avatar Jul 11 '24 13:07 aspannaus

Yes, I was using diffrax 0.6.0

lockwo avatar Jul 11 '24 16:07 lockwo

That was it, thanks again!

aspannaus avatar Jul 11 '24 16:07 aspannaus

Hi, I had the same issue as above:

ValueError: Custom node type mismatch: expected type: <class 'lineax._operator.DiagonalLinearOperator'>

I updated all the packages to the versions above and now I get the error:

AttributeError: module 'opt_einsum' has no attribute 'paths'

Do you have any ideas?

SoerenNagel avatar Aug 15 '24 12:08 SoerenNagel

What versions are you using?
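One way to collect them, as a small sketch using only the standard library (the function name `installed_versions` and the default package list are just choices for this example, not anything diffrax provides):

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "diffrax", "lineax", "equinox")):
    """Return a {package name: version string} dict, with None for missing packages."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Package is not installed in the active environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

This reads the metadata of whatever is installed in the active environment, so it reports the same versions regardless of whether a package came from pip or conda.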

lockwo avatar Aug 15 '24 21:08 lockwo

Hi Owen, I fixed the issue by setting up a new conda environment and making sure jax, jaxlib, equinox, lineax and diffrax were installed through pip and not conda (where diffrax 0.6.0 is not yet available). I don't really know what the underlying issue was. But thanks anyway.

SoerenNagel avatar Aug 15 '24 21:08 SoerenNagel