
Obscure error for passing EulerHeun instead of EulerHeun()

Open · rdaems opened this issue 2 months ago · 3 comments

Hi,

I've wasted some time because I made a small mistake, and the error was quite obscure. I passed EulerHeun instead of EulerHeun() as the solver argument. Minimal example:

import jax.random as jr
from diffrax import diffeqsolve, ControlTerm, EulerHeun, MultiTerm, ODETerm, SaveAt, VirtualBrownianTree

t0, t1 = 0, 3
drift = lambda t, y, args: -y
diffusion = lambda t, y, args: 0.1 * t
brownian_motion = VirtualBrownianTree(t0, t1, tol=1e-3, shape=(), key=jr.PRNGKey(0))
terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, brownian_motion))
solver = EulerHeun  # the mistake: this should be EulerHeun()
saveat = SaveAt(dense=True)

sol = diffeqsolve(terms, solver, t0, t1, dt0=0.05, y0=1.0, saveat=saveat)
print(sol.evaluate(1.1))  # DeviceArray(0.89436394)

This raises an AssertionError:

--> 137 assert type(term_contr_kwargs) is tuple

Passing Euler instead of Euler() gives a different error:

TypeError: functools.partial() argument after ** must be a mapping, not property

Maybe this can be improved? I can take a look later if I have some time.

rdaems · Nov 07 '25 09:11

Hi! Thanks for the report.

A simple isinstance(solver, dfx.AbstractSolver) check should detect at compile time whether the solver has been instantiated. Given that this would be a one-liner inside diffeqsolve, either we can easily implement it, or there is a good reason we haven't already. I'll take a look to make up my mind which it is! @lockwo, do you have any insight here?
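A minimal sketch of the guard described above, using a stand-in class hierarchy rather than diffrax itself: `AbstractSolver`, `EulerHeun` and `check_solver` below are hypothetical toys, and this is not the code from the eventual PR. The point it shows is that `isinstance` is False for the class object and True only for an instance, so the check separates `EulerHeun` from `EulerHeun()`, and a second `issubclass` test lets the error message suggest the fix.

```python
import abc


class AbstractSolver(abc.ABC):
    """Hypothetical stand-in for dfx.AbstractSolver (not diffrax's class)."""

    @abc.abstractmethod
    def step(self, y: float, dt: float) -> float: ...


class EulerHeun(AbstractSolver):
    """Hypothetical concrete solver; only the class/instance distinction matters."""

    def step(self, y: float, dt: float) -> float:
        return y - dt * y


def check_solver(solver: object) -> AbstractSolver:
    """Hypothetical guard: reject anything that is not a solver *instance*."""
    if isinstance(solver, AbstractSolver):
        return solver
    # Special-case the common mistake of passing the class itself.
    if isinstance(solver, type) and issubclass(solver, AbstractSolver):
        raise ValueError(
            f"`solver` should be an instance, e.g. `{solver.__name__}()`, "
            f"but got the class `{solver.__name__}` itself."
        )
    raise ValueError(f"`solver` must be an AbstractSolver instance, got {type(solver)}.")


print(isinstance(EulerHeun, AbstractSolver))    # False: the class object
print(isinstance(EulerHeun(), AbstractSolver))  # True: an instance

try:
    check_solver(EulerHeun)
except ValueError as e:
    print(e)
```

With a guard like this, the mistake surfaces as a direct message naming the fix rather than as an assertion deep inside the term check.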

Type-checking should also catch this, in case you need another tool in your arsenal to find errors like this faster. It would look like this:

import diffrax as dfx

from jaxtyping import jaxtyped
from beartype import beartype as typechecker

# Other code

sol = jaxtyped(typechecker=typechecker)(dfx.diffeqsolve)(...)

johannahaffner · Nov 07 '25 13:11
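To illustrate why a runtime type checker catches this mistake, here is a stdlib-only toy that mimics the core idea (it is not how beartype or jaxtyping are implemented): bind the call's arguments to the function signature and `isinstance`-check each one against its annotation. `check_types`, `solve` and the two stand-in classes are hypothetical.

```python
import functools
import inspect


class AbstractSolver:
    """Hypothetical stand-in for dfx.AbstractSolver."""


class EulerHeun(AbstractSolver):
    """Hypothetical stand-in for dfx.EulerHeun."""


def check_types(fn):
    """Toy runtime checker: isinstance-check arguments annotated with plain classes."""
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            param = sig.parameters[name]
            ann = param.annotation
            # Skip unannotated and *args/**kwargs parameters to keep the toy simple.
            if ann is inspect.Parameter.empty or param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue
            # Only plain classes are handled; real checkers also understand generics, unions, etc.
            if isinstance(ann, type) and not isinstance(value, ann):
                raise TypeError(
                    f"argument {name!r}={value!r} is not an instance of {ann.__name__}"
                )
        return fn(*args, **kwargs)

    return wrapper


@check_types
def solve(solver: AbstractSolver, t0: float, t1: float) -> str:
    return f"solving from {t0} to {t1} with {type(solver).__name__}"


print(solve(EulerHeun(), 0.0, 3.0))
try:
    solve(EulerHeun, 0.0, 3.0)  # the class, not an instance
except TypeError as e:
    print(e)
```

Because the check runs at the call boundary, the class-instead-of-instance mistake is reported against the `solver` argument itself instead of somewhere inside the solver machinery.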

They both fail the term check, but differently. I believe that's because EulerHeun.term_structure is a MultiTerm, so the check goes down a different branch, where term_contr_kwargs is expected to be a tuple but is a property instead. Euler.term_structure is an AbstractTerm, so it fails on the other branch when fetching the kwargs, again because it gets a property (just in a different way). A simple isinstance check would catch it. That's probably worth it: there are infinitely many ways user input can be invalid, and trying to catch them all is tilting at windmills, but catching reasonable mistakes is probably good.

lockwo · Nov 09 '25 06:11
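A toy reproduction of the two failure modes lockwo describes. The names (`ToySolver`, `ToyMultiSolver`, `check_terms`, `contr_kwargs`, and the two term classes) are hypothetical and only mirror the shape of the problem, not diffrax's actual internals. The assumption is that the kwargs live behind a `property`, so reading them from the class yields the `property` object itself, and which error surfaces depends on which branch `term_structure` selects.

```python
import functools


class AbstractTerm:
    """Hypothetical stand-in for a single-term structure."""


class MultiTerm:
    """Hypothetical stand-in for a multi-term structure."""


class ToySolver:
    # Plain class attribute: reads the same from the class or an instance.
    term_structure = AbstractTerm

    @property
    def contr_kwargs(self):
        # Property: an instance returns the dict; the class returns the property object.
        return {}


class ToyMultiSolver(ToySolver):
    term_structure = MultiTerm

    @property
    def contr_kwargs(self):
        # One kwargs dict per sub-term, packed in a tuple.
        return ({}, {})


def control(dt, **kwargs):
    return dt


def check_terms(solver):
    """Toy term check with one branch per term structure."""
    kwargs = solver.contr_kwargs
    if solver.term_structure is MultiTerm:
        # Branch 1: expects a tuple of kwargs dicts.
        assert type(kwargs) is tuple
        return [functools.partial(control, **kw) for kw in kwargs]
    # Branch 2: expects a single mapping to unpack.
    return functools.partial(control, **kwargs)


for solver in (ToySolver(), ToyMultiSolver()):
    check_terms(solver)  # instances pass both branches

for solver in (ToyMultiSolver, ToySolver):  # classes, as in the bug report
    try:
        check_terms(solver)
    except (AssertionError, TypeError) as e:
        # AssertionError on the MultiTerm branch, TypeError on the other.
        print(type(e).__name__, e)
```

An `isinstance` guard at the entry point rejects both classes before either branch runs, which is why a single check covers both symptoms.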

Thanks for taking a look! FWIW I think this is a mistake I've made myself as well in the past, but over in Optimistix.

I've implemented this in https://github.com/patrick-kidger/diffrax/pull/706 (not in one line, though).

johannahaffner · Nov 09 '25 10:11