diffrax

[Feature request] differential algebraic equations

Open mfkasim1 opened this issue 3 years ago • 5 comments

Hi again, I've found this repository useful too. One thing that would be ideal to have is a solver for differential algebraic equations (DAEs) (http://www.scholarpedia.org/article/Differential-algebraic_equations), at least for the semi-explicit form. Is there a plan to add this to diffrax?

mfkasim1, Feb 14 '22 10:02

In principle, absolutely -- this would be great. Realistically this is something I'm unlikely to add myself in the near future, however. But if this is important to you then I'd be very happy to work with you to put together a PR adding support for this kind of thing.

At least for semi-explicit DAEs, most of the support needed from Diffrax should already be present. In fact I think it should be possible to make things work through the existing AbstractTerm / AbstractSolver interfaces, without needing to make any internal changes to Diffrax.

Untested and hastily thrown together, a first approach might look something like this.

We'll solve the semi-explicit DAE

dy(t) = f(t, y(t), z(t), args) dx(t)
0 = g(t, y(t), z(t), args)

by first solving

y_{n+1} = ODESolve(f, t_n, t_{n+1}, y_n, z_n, args, x|[t_n, t_{n+1}])

with a single step of an existing solver, and then solving

0 = g(t_{n+1}, y_{n+1}, z_{n+1}, args)

via a nonlinear solver.

(If the dx(t) notation looks unfamiliar (not many people are familiar with controlled differential equations), feel free to read dx(t) as dt and rearrange the expressions above into the usual dy/dt form. The version above is just the general one, which also works for SDEs etc.)
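A minimal plain-Python sketch of the splitting scheme above, for scalar y and z with dx(t) = dt. This is not Diffrax code: explicit Euler stands in for the single ODE step, a finite-difference Newton iteration stands in for the nonlinear solve, and the helper names (`newton_scalar`, `semi_explicit_euler`) and the toy problem are hypothetical.

```python
import math


def newton_scalar(g, z0, tol=1e-12, max_iter=50, eps=1e-7):
    """Solve g(z) = 0 for scalar z by Newton's method (central-difference derivative)."""
    z = z0
    for _ in range(max_iter):
        gz = g(z)
        if abs(gz) < tol:
            break
        dg = (g(z + eps) - g(z - eps)) / (2 * eps)
        z = z - gz / dg
    return z


def semi_explicit_euler(f, g, t0, t1, y0, z0, args, num_steps):
    """Split scheme: advance y with z frozen, then solve the constraint for the new z."""
    dt = (t1 - t0) / num_steps
    t, y, z = t0, y0, z0
    for _ in range(num_steps):
        # 1. One explicit Euler step for y over [t_n, t_{n+1}], holding z fixed at z_n.
        y = y + dt * f(t, y, z, args)
        t = t + dt
        # 2. Find z_{n+1} with 0 = g(t_{n+1}, y_{n+1}, z_{n+1}, args), warm-started from z_n.
        z = newton_scalar(lambda z_new: g(t, y, z_new, args), z)
    return y, z


# Hypothetical toy DAE: dy/dt = z, 0 = z + k*y, so the exact solution is y(t) = exp(-k t).
k = 1.0
y_end, z_end = semi_explicit_euler(
    f=lambda t, y, z, k: z,
    g=lambda t, y, z, k: z + k * y,
    t0=0.0, t1=1.0, y0=1.0, z0=-1.0, args=k, num_steps=1000,
)
print(y_end, math.exp(-k))  # close, up to the O(dt) error of explicit Euler
```

Warm-starting the root-find from z_n is a cheap default, since z usually changes little over a single step.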

For the end user, the code would be used like so.

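from diffrax import Kvaerno5, ODETerm, diffeqsolve

def vector_field(t, y, z__args):
    z, args = z__args
    return dy_dt  # placeholder: compute dy/dt from t, y, z and args

def constraint(t, y, z__args):
    z, args = z__args
    return value_that_should_be_zero  # placeholder: the residual g(t, y, z, args)

term = ConstrainedTerm(ODETerm(vector_field), constraint)
solver = SemiExplicitConstrainedSolver(solver=Kvaerno5())
# The initial state is the pair (y0, z0), with z0 consistent with the constraint.
diffeqsolve(term, solver, ...)  # as normal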
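As a hypothetical instance, take the toy system dy/dt = z with 0 = z + k y (so y(t) = y(0) exp(-k t)), with k passed through `args`. The two user functions then unpack `z` from the `(z, args)` bundle exactly as in the template. With the implementation below, the initial state would be the pair `(y0, z0)`, e.g. `(1.0, -1.0)` for k = 1, so that the constraint already holds at t0.

```python
def vector_field(t, y, z__args):
    # z (the algebraic variable) arrives bundled together with the user's args.
    z, args = z__args
    return z  # dy/dt = z


def constraint(t, y, z__args):
    z, args = z__args
    k = args  # hypothetical: a single scalar parameter passed via args
    return z + k * y  # zero along the solution


print(vector_field(0.0, 2.0, (3.0, 1.0)))
print(constraint(0.0, 2.0, (-2.0, 1.0)))
```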

And finally the implementation is as follows.

from typing import Any, Callable

import jax
import jax.numpy as jnp
from diffrax import (
    AbstractImplicitSolver,
    AbstractTerm,
    AbstractWrappedSolver,
    LocalLinearInterpolation,
)

# First just wrap together a term and a constraint.
# We arrange it so that the `z` component of the DAE is passed through `args` to the
# user-specified vector field and constraint.
class ConstrainedTerm(AbstractTerm):
    term: AbstractTerm
    constraint: Callable[..., Any]  # (t, y, (z, args)) -> residual that should be zero

    def vf(self, t, y, args):
        y, z = y
        return self.term.vf(t, y, (z, args))

    def contr(self, t0, t1):
        return self.term.contr(t0, t1)

    def prod(self, vf, control):
        return self.term.prod(vf, control)

    def vf_prod(self, t, y, args, control):
        y, z = y
        return self.term.vf_prod(t, y, (z, args), control)

    def constr(self, t, y, args):
        y, z = y
        return self.constraint(t, y, (z, args))

def _implicit_relation(z1, nonlinear_solve_args):
    constraint_fn, t1, y1, args = nonlinear_solve_args
    return constraint_fn(t1, (y1, z1), args)

# AbstractWrappedSolver gives us access to self.solver (an ODE/SDE/etc. solver)
# AbstractImplicitSolver gives us access to self.nonlinear_solver
class SemiExplicitConstrainedSolver(AbstractWrappedSolver, AbstractImplicitSolver):
    # A class attribute, not a field: the terms are a single ConstrainedTerm.
    term_structure = jax.tree_util.tree_structure(0)
    interpolation_cls = LocalLinearInterpolation

    def order(self, terms):
        return self.solver.order(terms)

    def strong_order(self, terms):
        return self.solver.strong_order(terms)

    def error_order(self, terms):
        return self.solver.error_order(terms)

    def init(self, terms, t0, t1, y0, args):
        # The wrapped solver only ever sees the differential part `y`; `z` travels in `args`.
        y0, z0 = y0
        return self.solver.init(terms.term, t0, t1, y0, (z0, args))

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        y0, z0 = y0
        # 1. One step of the wrapped solver for y, with z frozen at z0.
        y1, y_error, _, solver_state, result = self.solver.step(
            terms.term, t0, t1, y0, (z0, args), solver_state, made_jump
        )
        # 2. Solve 0 = g(t1, y1, z1, args) for z1, using z0 as the initial guess.
        nonlinear_args = (terms.constr, t1, y1, args)
        jac = self.nonlinear_solver.jac(_implicit_relation, z0, nonlinear_args)
        nonlinear_sol = self.nonlinear_solver(_implicit_relation, z0, nonlinear_args, jac)
        z1 = nonlinear_sol.root
        # No error estimate is made for z (see the comments below).
        z_error = jax.tree_util.tree_map(jnp.zeros_like, z1)
        dense_info = dict(y0=(y0, z0), y1=(y1, z1))
        result = jnp.maximum(result, nonlinear_sol.result)
        return (y1, z1), (y_error, z_error), dense_info, solver_state, result

Various comments on this implementation:

  • I've not looked too closely at the details of solving a DAE. I don't know how effective/stable/etc. the above approach is numerically.
  • The output via SaveAt(ts=...) or SaveAt(dense=...) uses the specified interpolation_cls, which for simplicity here is just linear interpolation. A more serious implementation here would find a way to use the interpolation_cls of the wrapped solver.
  • It estimates that zero error is made in obtaining the solution for z (hence the z_error = ... line). For context these error estimates are used to handle adaptive time stepping.
  • The nonlinear solver is done via the chord method. If we removed the jac = ... line (and just passed jac = None instead) then it would be via Newton's method.
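To illustrate the last point, here is a scalar plain-Python sketch with hypothetical helper names (not Diffrax's nonlinear solver API). The chord method evaluates the derivative once, at the initial guess, and reuses it, which is what passing a precomputed `jac` amounts to; Newton's method re-evaluates it at every iterate. The toy residual `g` has its root at z = 1.

```python
def newton(g, dg, z0, num_iters):
    """Newton's method: re-evaluate the derivative at every iterate."""
    z = z0
    for _ in range(num_iters):
        z = z - g(z) / dg(z)
    return z


def chord(g, dg, z0, num_iters):
    """Chord method: evaluate the derivative once, at the initial guess, then reuse it."""
    slope = dg(z0)
    z = z0
    for _ in range(num_iters):
        z = z - g(z) / slope
    return z


g = lambda z: z + 0.1 * z**2 - 1.1  # root at z = 1
dg = lambda z: 1.0 + 0.2 * z

z_newton = newton(g, dg, 0.0, 6)
z_chord = chord(g, dg, 0.0, 6)
print(abs(z_newton - 1.0), abs(z_chord - 1.0))
```

The chord method converges only linearly here, while Newton converges quadratically near the root; the payoff of the chord method is that the Jacobian (and, for vector-valued z, its factorisation) is computed once per step rather than once per iteration.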

patrick-kidger, Feb 14 '22 13:02

Thanks for your answer! I'll try to digest it first.

mfkasim1, Feb 15 '22 19:02

@patrick-kidger Thanks for the great library. I can make this PR if you'd like to help. I'm in need of this functionality.

YangyangFu, Nov 29 '22 06:11

Sure, I'd be happy to see what you come up with.

I'm imagining an API looking something like

        constraints=<pytree of constraint functions>,
        z0=<value for extra state>,
        ... # any other options, e.g. choice of nonlinear solver for the projection step

FWIW we're currently implementing delay diffeqs over in #169, and I imagine this will have some overlap. (E.g. binding the extra z in a similar way to the way history is bound over there, using VectorFieldWrapper.)

patrick-kidger, Nov 29 '22 14:11

Hi, I'm trying to define a new solver wrapper like the SemiExplicitConstrainedSolver example above. The class looks something like this:

class CustomSolver(AbstractWrappedSolver, AbstractImplicitSolver):
    term_structure = jax.tree_util.tree_structure(0)
    interpolation_cls = ThirdOrderHermitePolynomialInterpolation.from_k  # Like the one used for Kvaerno5
    def order() ...
    def strong_order() ...
    def error_order() ...
    def init() ...
    def step() ...
    def func() ...

However, I get the following error:

KeyError                                  Traceback (most recent call last)
File ~/sbml2gpu/venv/lib/python3.10/site-packages/equinox/_better_abstract.py:239, in dataclass.<locals>.make_dataclass(cls)
    238 try:
--> 239     annotations = cls.__dict__["__annotations__"]
    240 except KeyError:

KeyError: '__annotations__'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
Cell In[12], line 1
----> 1 class ExplicitDAESolver(AbstractWrappedSolver, AbstractImplicitSolver):
      2     # Use the interpolation scheme from the Kvaerno5 solver
      3     # since we will use that and no other solvers.
      4     term_structure = jax.tree_util.tree_structure(0)
      5     interpolation_cls = ThirdOrderHermitePolynomialInterpolation.from_k

File ~/sbml2gpu/venv/lib/python3.10/site-packages/equinox/_module.py:107, in _ModuleMeta.__new__(mcs, name, bases, dict_)
    105 if _init:
    106     init_doc = cls.__init__.__doc__
--> 107 cls = dataclass(eq=False, repr=False, frozen=True, init=_init)(
    108     cls  # pyright: ignore
    109 )
    110 if _init:
    111     cls.__init__.__doc__ = init_doc  # pyright: ignore

File ~/sbml2gpu/venv/lib/python3.10/site-packages/equinox/_better_abstract.py:241, in dataclass.<locals>.make_dataclass(cls)
    239     annotations = cls.__dict__["__annotations__"]
    240 except KeyError:
--> 241     cls = dataclasses.dataclass(**kwargs)(cls)
    242 else:
    243     new_annotations = dict(annotations)

File /usr/lib/python3.10/dataclasses.py:1176, in dataclass.<locals>.wrap(cls)
   1175 def wrap(cls):
-> 1176     return _process_class(cls, init, repr, eq, order, unsafe_hash,
   1177                           frozen, match_args, kw_only, slots)

File /usr/lib/python3.10/dataclasses.py:1025, in _process_class(cls, init, repr, eq, order, unsafe_hash, frozen, match_args, kw_only, slots)
   1020 if init:
   1021     # Does this class have a post-init function?
   1022     has_post_init = hasattr(cls, _POST_INIT_NAME)
   1024     _set_new_attribute(cls, '__init__',
-> 1025                        _init_fn(all_init_fields,
   1026                                 std_init_fields,
   1027                                 kw_only_init_fields,
   1028                                 frozen,
   1029                                 has_post_init,
   1030                                 # The name to use for the "self"
   1031                                 # param in __init__.  Use "self"
   1032                                 # if possible.
   1033                                 '__dataclass_self__' if 'self' in fields
   1034                                         else 'self',
   1035                                 globals,
   1036                                 slots,
   1037                       ))
   1039 # Get the fields as a list, and include only real fields.  This is
   1040 # used in all of the following methods.
   1041 field_list = [f for f in fields.values() if f._field_type is _FIELD]

File /usr/lib/python3.10/dataclasses.py:546, in _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init, self_name, globals, slots)
    544             seen_default = True
    545         elif seen_default:
--> 546             raise TypeError(f'non-default argument {f.name!r} '
    547                             'follows default argument')
    549 locals = {f'_type_{f.name}': f.type for f in fields}
    550 locals.update({
    551     'MISSING': MISSING,
    553 })

TypeError: non-default argument 'solver' follows default argument

If I remove AbstractImplicitSolver from the parent classes, the class definition works, but when I run the simulation I get an error saying that the CustomSolver object has no attribute 'nonlinear_solver' (because I'm using Kvaerno5 as the wrapped solver). So I'm stuck: I need AbstractImplicitSolver in CustomSolver, but inheriting from it produces the error above.

Any suggestions?
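The TypeError is Python's ordinary dataclass rule that a field without a default cannot follow a field with one. Dataclasses collect fields from the base classes in reverse MRO order, so with `class CustomSolver(AbstractWrappedSolver, AbstractImplicitSolver)` the fields of AbstractImplicitSolver come first. The traceback is consistent with AbstractImplicitSolver's `nonlinear_solver` having a default while AbstractWrappedSolver's `solver` has none (an assumption about that Diffrax version). The sketch below uses plain dataclasses with hypothetical stand-in classes to reproduce the error and show that reversing the base order avoids it; whether the same swap is enough under Equinox's `Module` machinery is untested here.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the Diffrax base classes (not the real ones).
@dataclass(frozen=True)
class SolverBase:
    pass


@dataclass(frozen=True)
class WrappedBase(SolverBase):
    solver: object  # required field, no default


@dataclass(frozen=True)
class ImplicitBase(SolverBase):
    nonlinear_solver: object = "newton"  # field with a default


# Same base order as in the traceback: ImplicitBase's defaulted field is
# collected first, then WrappedBase's required `solver`, so a TypeError is raised.
try:
    @dataclass(frozen=True)
    class BrokenOrder(WrappedBase, ImplicitBase):
        pass

    error_message = None
except TypeError as err:
    error_message = str(err)


# Reversed base order: `solver` is collected first, so the field order is valid.
@dataclass(frozen=True)
class WorkingOrder(ImplicitBase, WrappedBase):
    pass


print(error_message)
print(WorkingOrder(solver="kvaerno5"))
```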

lmriccardo, Jun 01 '23 18:06