diffrax
Hypersolvers
Hypersolvers are learnt solvers, as per https://arxiv.org/abs/2007.09601. See also their implementation in https://github.com/DiffEqML/torchdyn
I wonder if you would like a tutorial example for this type of solver. The implementation of a hypersolver in diffrax is quite straightforward:
```python
import operator
import jax
import diffrax

class HyperEuler(diffrax.Euler):
    hypernet: diffrax.AbstractTerm

    def __init__(self, hypernet: diffrax.AbstractTerm):
        self.hypernet = hypernet

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        # Take a plain Euler step...
        y1, _, dense_info, _, result = super().step(terms, t0, t1, y0, args, solver_state, made_jump)
        # ...then add the learnt correction, scaled by dt**(order + 1).
        control = terms.contr(t0, t1)
        control = jax.tree_map(lambda c: c ** (self.order(terms) + 1), control)
        y1 = jax.tree_map(operator.add, y1, self.hypernet.vf_prod(t=t0, y=y0, args=None, control=control))
        return y1, None, dense_info, None, result
```
I made a script trying to replicate this torchdyn tutorial, here: https://colab.research.google.com/drive/1c9_AM-5NiLox1Do-DpiVJ-KXQKZDngvu?usp=sharing
Hmm, I'd need to think about the cleanest abstraction (it feels a bit weird to me to pass a term into a solver at init time) but something like your approach makes sense.
Perhaps also make the base solver (here Euler) an input as well, so that there's a single Hypersolver class rather than a HyperEuler etc. For this, AbstractWrappedSolver may be of help.
(Also we could probably forward the error estimate etc through safely?)
I'd be happy to accept a PR on this, as either something in the base library or as one of the examples.
Got it! AbstractWrappedSolver is the way to go. I will gladly make a PR on this.