diffrax
[Feature request] differential algebraic equations
Hi again, I've found this repository useful too. One thing that would be ideal to have is a differential-algebraic equation (DAE) solver (http://www.scholarpedia.org/article/Differential-algebraic_equations), at least for the semi-explicit form. Is there a plan to add this to diffrax?
In principle, absolutely -- this would be great. Realistically this is something I'm unlikely to add myself in the near future, however. But if this is important to you then I'd be very happy to work with you to put together a PR adding support for this kind of thing.
At least for semi-explicit DAEs, most of the support needed from Diffrax should already be present. In fact I think it should be possible to make things work through the existing `AbstractTerm`/`AbstractSolver` interfaces, without needing to make any internal changes to Diffrax.
Untested and hastily thrown together, a first approach might look something like this.
We'll solve the semi-explicit DAE

$$
\begin{aligned}
\mathrm{d}y(t) &= f(t, y(t), z(t), \text{args})\,\mathrm{d}x(t) \\
0 &= g(t, y(t), z(t), \text{args})
\end{aligned}
$$

by first solving

$$
y_{n+1} = \mathrm{ODESolve}\big(f, t_n, t_{n+1}, y_n, z_n, \text{args}, x|_{[t_n, t_{n+1}]}\big)
$$

with a single step of an existing solver, and then solving

$$
0 = g(t_{n+1}, y_{n+1}, z_{n+1}, \text{args})
$$

for $z_{n+1}$ via a nonlinear solver.
(If the $\mathrm{d}x(t)$ looks unfamiliar to you (not that many folks are familiar with controlled differential equations), then feel free to think of it as $\mathrm{d}x(t) = \mathrm{d}t$ and rearrange the above expression into the more usual $\mathrm{d}y/\mathrm{d}t$ form. The above is just the general version that also works for SDEs etc.)
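To make the two-stage step concrete, here is a minimal pure-JAX sketch that does not use Diffrax at all. It assumes $\mathrm{d}x(t) = \mathrm{d}t$, uses a single explicit Euler step as a stand-in for "one step of an existing solver", and uses plain Newton iterations as the nonlinear solver. The toy DAE $\mathrm{d}y/\mathrm{d}t = -z$, $0 = z + z^3 - y$ and every function name below (`f`, `g`, `newton_solve`, `step`, `solve`) are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy semi-explicit DAE (not from this thread):
#   dy/dt = f(t, y, z, args) = -z
#   0     = g(t, y, z, args) = z + z^3 - y


def f(t, y, z, args):
    return -z


def g(t, y, z, args):
    return z + z**3 - y


def newton_solve(fn, z_init, num_iters=20):
    # Plain Newton iteration for a scalar root of fn(z) = 0.
    dfn = jax.grad(fn)

    def body(z, _):
        return z - fn(z) / dfn(z), None

    z, _ = jax.lax.scan(body, z_init, None, length=num_iters)
    return z


def step(t0, t1, y0, z0, args):
    # 1. One explicit Euler step for y with z frozen at z0 (taking dx = dt).
    y1 = y0 + (t1 - t0) * f(t0, y0, z0, args)
    # 2. Solve 0 = g(t1, y1, z1, args) for z1, warm-started from z0.
    z1 = newton_solve(lambda z: g(t1, y1, z, args), z0)
    return y1, z1


def solve(y0, z0, t0, t1, num_steps, args=None):
    # Fixed-step march from t0 to t1; (y0, z0) should already satisfy g = 0.
    ts = jnp.linspace(t0, t1, num_steps + 1)
    intervals = jnp.stack([ts[:-1], ts[1:]], axis=1)

    def body(carry, interval):
        y, z = carry
        return step(interval[0], interval[1], y, z, args), None

    init = (jnp.asarray(y0, dtype=ts.dtype), jnp.asarray(z0, dtype=ts.dtype))
    (y, z), _ = jax.lax.scan(body, init, intervals)
    return y, z


y_end, z_end = solve(2.0, 1.0, 0.0, 1.0, num_steps=200)
```

Starting from $y_0 = 2$, the consistent algebraic state is $z_0 = 1$, because $1 + 1^3 = 2$. Warm-starting Newton from the previous $z_n$ typically means only a few iterations are needed per step.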
For the end user, the code would be used like so.
```python
from diffrax import Kvaerno5, ODETerm, diffeqsolve


def vector_field(t, y, z__args):
    z, args = z__args
    ...
    return dy_dt


def constraint(t, y, z__args):
    z, args = z__args
    ...
    return value_that_should_be_zero


term = ConstrainedTerm(ODETerm(vector_field), constraint)
solver = SemiExplicitConstrainedSolver(Kvaerno5())
diffeqsolve(term, solver, ...)  # as normal, with the initial state given as the pair (y0, z0)
```
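As a hypothetical illustration of the `z__args` convention (the toy DAE $\mathrm{d}y/\mathrm{d}t = -z$, $0 = z + z^3 - y$ is not from this thread), the two placeholders might be filled in as follows. Only the user-facing functions are shown, so they can be checked on their own without the wrapper classes.

```python
import jax.numpy as jnp

# Hypothetical toy DAE (not from this thread): dy/dt = -z, 0 = z + z^3 - y.


def vector_field(t, y, z__args):
    z, args = z__args
    dy_dt = -z
    return dy_dt


def constraint(t, y, z__args):
    z, args = z__args
    value_that_should_be_zero = z + z**3 - y
    return value_that_should_be_zero


# A consistent initial pair: z0 = 1 satisfies 1 + 1^3 = 2 = y0.
y0, z0 = jnp.array(2.0), jnp.array(1.0)
```

Choosing a consistent pair like this means the constraint already holds at the initial time, so the first nonlinear solve starts from an exact root.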
And finally the implementation is as follows.
```python
from typing import Callable

import jax
import jax.numpy as jnp
from diffrax import AbstractImplicitSolver, AbstractTerm, AbstractWrappedSolver
from diffrax import LocalLinearInterpolation


# First just wrap together a term and a constraint
#
# We arrange it so that the `z` component of the DAE is passed through `args` to the
# user-specified vector field and constraint.
class ConstrainedTerm(AbstractTerm):
    term: AbstractTerm
    # Called as constraint(t, y, (z, args)); returns the residual that should be zero.
    constraint: Callable

    def vf(self, t, y, args):
        y, z = y
        return self.term.vf(t, y, (z, args))

    def contr(self, t0, t1):
        return self.term.contr(t0, t1)

    def prod(self, vf, control):
        return self.term.prod(vf, control)

    def vf_prod(self, t, y, args, control):
        y, z = y
        return self.term.vf_prod(t, y, (z, args), control)

    def constr(self, t, y, args):
        y, z = y
        return self.constraint(t, y, (z, args))


# Residual for the nonlinear solve, in the fn(x, args) form the nonlinear solver expects.
def _implicit_relation(z1, nonlinear_solve_args):
    constraint_fn, t1, y1, args = nonlinear_solve_args
    return constraint_fn(t1, (y1, z1), args)


# AbstractWrappedSolver gives us access to self.solver (an ODE/SDE/etc. solver)
# AbstractImplicitSolver gives us access to self.nonlinear_solver
class SemiExplicitConstrainedSolver(AbstractWrappedSolver, AbstractImplicitSolver):
    term_structure = jax.tree_util.tree_structure(0)
    interpolation_cls = LocalLinearInterpolation

    # The wrapped solver only ever sees the inner term, with `z` carried in `args`.
    def order(self, terms):
        return self.solver.order(terms.term)

    def strong_order(self, terms):
        return self.solver.strong_order(terms.term)

    def error_order(self, terms):
        return self.solver.error_order(terms.term)

    def init(self, terms, t0, t1, y0, args):
        y0, z0 = y0
        return self.solver.init(terms.term, t0, t1, y0, (z0, args))

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        y0, z0 = y0
        # Advance y by one step of the wrapped solver, holding z fixed at z0.
        y1, y_error, _, solver_state, result = self.solver.step(
            terms.term, t0, t1, y0, (z0, args), solver_state, made_jump
        )
        # Solve g(t1, y1, z1, args) = 0 for z1, using z0 as the initial guess.
        nonlinear_args = (terms.constr, t1, y1, args)
        jac = self.nonlinear_solver.jac(_implicit_relation, z0, nonlinear_args)
        nonlinear_sol = self.nonlinear_solver(_implicit_relation, z0, nonlinear_args, jac)
        z1 = nonlinear_sol.root
        z_error = jax.tree_util.tree_map(jnp.zeros_like, z1)
        dense_info = dict(y0=(y0, z0), y1=(y1, z1))
        result = jnp.maximum(result, nonlinear_sol.result)
        return (y1, z1), (y_error, z_error), dense_info, solver_state, result
```
Various comments on this implementation:
- I've not looked too closely at the details of solving DAEs, so I don't know how effective/stable/etc. the above approach is numerically.
- The output via `SaveAt(ts=...)` or `SaveAt(dense=...)` uses the specified `interpolation_cls`, which for simplicity here is just linear interpolation. A more serious implementation would find a way to use the `interpolation_cls` of the wrapped solver.
- It estimates that zero error is made in obtaining the solution for `z` (hence the `z_error = ...` line). For context, these error estimates are used to handle adaptive time stepping.
- The nonlinear solve is done via the chord method. If we removed the `jac = ...` line (and just passed `jac=None` instead), then it would be done via Newton's method.
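For readers unfamiliar with the distinction, here is a small pure-JAX sketch on a scalar residual. It is a hypothetical example, independent of Diffrax's own nonlinear solvers: the chord method evaluates the derivative (Jacobian) once, at the initial guess, and reuses it, whereas Newton's method re-evaluates it at every iterate. Chord iterations are cheaper but generally converge only linearly, while Newton converges quadratically near a root.

```python
import jax


def residual(z):
    # Hypothetical scalar constraint g(z) = z + z^3 - 2, with root z = 1.
    return z + z**3 - 2.0


def newton(fn, z0, num_iters):
    dfn = jax.grad(fn)
    z = z0
    for _ in range(num_iters):
        z = z - fn(z) / dfn(z)  # derivative re-evaluated at every iterate
    return z


def chord(fn, z0, num_iters):
    slope = jax.grad(fn)(z0)  # derivative evaluated once, at the initial guess
    z = z0
    for _ in range(num_iters):
        z = z - fn(z) / slope
    return z


z_newton = newton(residual, 0.9, 3)
z_chord = chord(residual, 0.9, 3)
```

After three iterations from the same starting point, Newton's residual is already near machine precision while the chord method's is still around $10^{-3}$; given enough iterations, the chord method converges too.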
Thanks for your answer! I'll try to digest it first.
@patrick-kidger Thanks for the great library. I can make this PR if you'd like to help. I'm in need of this functionality.
Sure, I'd be happy to see what you come up with.
I'm imagining an API looking something like:

```python
diffeqsolve(
    ...
    constraints=Constraints(
        constraints=<pytree of constraint functions>,
        z0=<value for extra state>,
        ...  # any other options, e.g. choice of nonlinear solver for the projection step
    ),
)
```
FWIW we're currently implementing delay diffeqs over in #169, and I imagine this will have some overlap. (E.g. binding the extra `z` in a similar way to the way history is bound over there, using `VectorFieldWrapper`.)
Hi, I'm trying to define a new solver wrapper just like the `SemiExplicitConstrainedSolver` example above. The class is something like this:
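```python
import jax
from diffrax import (
    AbstractImplicitSolver,
    AbstractWrappedSolver,
    ThirdOrderHermitePolynomialInterpolation,
)


class CustomSolver(AbstractWrappedSolver, AbstractImplicitSolver):
    term_structure = jax.tree_util.tree_structure(0)
    interpolation_cls = ThirdOrderHermitePolynomialInterpolation.from_k  # Like the one used for Kvaerno5

    def order(self, terms): ...
    def strong_order(self, terms): ...
    def error_order(self, terms): ...
    def init(self, terms, t0, t1, y0, args): ...
    def step(self, terms, t0, t1, y0, args, solver_state, made_jump): ...
    def func(self, terms, t0, y0, args): ...
```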
However, I get the following error:
```
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File ~/sbml2gpu/venv/lib/python3.10/site-packages/equinox/_better_abstract.py:239, in dataclass.<locals>.make_dataclass(cls)
    238 try:
--> 239     annotations = cls.__dict__["__annotations__"]
    240 except KeyError:

KeyError: '__annotations__'

During handling of the above exception, another exception occurred:

TypeError Traceback (most recent call last)
Cell In[12], line 1
----> 1 class ExplicitDAESolver(AbstractWrappedSolver, AbstractImplicitSolver):
      2     # Use the interpolation scheme from the Kvaerno5 solver
      3     # since we will use that and no other solvers.
      4     term_structure = jax.tree_util.tree_structure(0)
      5     interpolation_cls = ThirdOrderHermitePolynomialInterpolation.from_k

File ~/sbml2gpu/venv/lib/python3.10/site-packages/equinox/_module.py:107, in _ModuleMeta.__new__(mcs, name, bases, dict_)
    105 if _init:
    106     init_doc = cls.__init__.__doc__
--> 107 cls = dataclass(eq=False, repr=False, frozen=True, init=_init)(
    108     cls  # pyright: ignore
    109 )
    110 if _init:
    111     cls.__init__.__doc__ = init_doc  # pyright: ignore

File ~/sbml2gpu/venv/lib/python3.10/site-packages/equinox/_better_abstract.py:241, in dataclass.<locals>.make_dataclass(cls)
    239     annotations = cls.__dict__["__annotations__"]
    240 except KeyError:
--> 241     cls = dataclasses.dataclass(**kwargs)(cls)
    242 else:
    243     new_annotations = dict(annotations)

File /usr/lib/python3.10/dataclasses.py:1176, in dataclass.<locals>.wrap(cls)
   1175 def wrap(cls):
-> 1176     return _process_class(cls, init, repr, eq, order, unsafe_hash,
   1177                           frozen, match_args, kw_only, slots)

File /usr/lib/python3.10/dataclasses.py:1025, in _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots)
   1020 if init:
   1021     # Does this class have a post-init function?
   1022     has_post_init = hasattr(cls, _POST_INIT_NAME)
   1024     _set_new_attribute(cls, '__init__',
-> 1025                        _init_fn(all_init_fields,
   1026                                 std_init_fields,
   1027                                 kw_only_init_fields,
   1028                                 frozen,
   1029                                 has_post_init,
   1030                                 # The name to use for the "self"
   1031                                 # param in __init__.  Use "self"
   1032                                 # if possible.
   1033                                 '__dataclass_self__' if 'self' in fields
   1034                                         else 'self',
   1035                                 globals,
   1036                                 slots,
   1037                       ))
   1039 # Get the fields as a list, and include only real fields.  This is
   1040 # used in all of the following methods.
   1041 field_list = [f for f in fields.values() if f._field_type is _FIELD]

File /usr/lib/python3.10/dataclasses.py:546, in _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init, self_name, globals, slots)
    544             seen_default = True
    545         elif seen_default:
--> 546             raise TypeError(f'non-default argument {f.name!r} '
    547                             'follows default argument')
    549 locals = {f'_type_{f.name}': f.type for f in fields}
    550 locals.update({
    551     'MISSING': MISSING,
    552     '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
    553 })

TypeError: non-default argument 'solver' follows default argument
```
If I remove the `AbstractImplicitSolver` parent, the class definition goes through without any error, but as soon as I try to simulate I get an error that the `CustomSolver` object has no attribute `nonlinear_solver` (because I'm using Kvaerno5 as the wrapped solver). So I'm stuck: I need `AbstractImplicitSolver` in `CustomSolver`, but adding it gives the error above.
Any suggestions?
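For context, the `TypeError` in the traceback is standard `dataclasses` behaviour under multiple inheritance, and it can be reproduced without Diffrax. Fields are collected from the base classes in reverse MRO order, so with `class X(AbstractWrappedSolver, AbstractImplicitSolver)` the fields of `AbstractImplicitSolver` come first. Assuming (as the traceback suggests) that `AbstractImplicitSolver` declares a field with a default value while `AbstractWrappedSolver`'s `solver` field has none, the defaulted field lands before `solver`. The classes below are hypothetical stand-ins for the two Diffrax bases, not the real ones.

```python
import dataclasses


@dataclasses.dataclass
class WrappedBase:
    # Stand-in for AbstractWrappedSolver: a required field with no default.
    solver: object


@dataclasses.dataclass
class ImplicitBase:
    # Stand-in for AbstractImplicitSolver: a field that has a default value.
    nonlinear_solver: object = "newton"


bad_error = None
try:
    # Reverse MRO is (object, ImplicitBase, WrappedBase), so the fields are ordered
    # nonlinear_solver (default) then solver (no default): __init__ cannot be built.
    @dataclasses.dataclass
    class Bad(WrappedBase, ImplicitBase):
        pass
except TypeError as e:
    bad_error = str(e)


# Swapping the bases gives the order solver, nonlinear_solver, which is valid.
@dataclasses.dataclass
class Good(ImplicitBase, WrappedBase):
    pass


good = Good(solver="kvaerno5")
```

In this stand-in, listing the base with the defaulted field first puts `solver` first and the error goes away. Whether the same reordering is enough for the real Diffrax classes, which go through Equinox's own dataclass handling, is untested here.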