diffrax

Hypersolvers

Open · patrick-kidger opened this issue 2 years ago · 3 comments

Hypersolvers are learnt solvers, as introduced in https://arxiv.org/abs/2007.09601. See also their implementation in https://github.com/DiffEqML/torchdyn.
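For reference, this is my reading of the update rule in the linked paper. A base solver step ψ of order p, with step size ε, is augmented by a learned correction g_ω scaled by ε^{p+1}, so g_ω only has to approximate the base solver's local truncation error divided by ε^{p+1}. Here s_k is the current time and x_k the current state.

```latex
x_{k+1} = x_k + \epsilon\,\psi(s_k, x_k) + \epsilon^{p+1}\, g_\omega(\epsilon, s_k, x_k)
```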

patrick-kidger, Aug 18 '21 17:08

I wonder whether you'd like a tutorial example for this type of solver. Implementing a hypersolver in diffrax is quite straightforward:

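import operator

import diffrax
import jax

class HyperEuler(diffrax.Euler):
    # Learned correction term, added on top of each base Euler step.
    hypernet: diffrax.AbstractTerm

    def __init__(self, hypernet: diffrax.AbstractTerm):
        self.hypernet = hypernet

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        # Ordinary Euler step first.
        y1, y_error, dense_info, solver_state, result = super().step(
            terms, t0, t1, y0, args, solver_state, made_jump
        )
        # Raise the control (dt for an ODETerm) to the power p + 1, p = base solver order.
        order = self.order(terms)
        control = terms.contr(t0, t1)
        control = jax.tree_util.tree_map(lambda c: c ** (order + 1), control)
        # Add the hypernet correction, evaluated at the start of the step.
        correction = self.hypernet.vf_prod(t0, y0, args, control)
        y1 = jax.tree_util.tree_map(operator.add, y1, correction)
        return y1, y_error, dense_info, solver_state, result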
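To see why the control is raised to the power order + 1, here is a plain-JAX sketch with no diffrax dependency. It assumes the toy linear ODE dy/dt = λy and a hand-picked "hypernet" equal to the leading Euler truncation-error coefficient λ²y/2, which a trained network would only approximate. `LAM`, `hypernet` and the step functions are illustrative names, not diffrax API.

```python
import jax.numpy as jnp

LAM = -1.0  # toy ODE dy/dt = LAM * y (illustrative assumption)


def f(t, y):
    return LAM * y


def hypernet(t, y):
    # Hand-picked "ideal" correction for this linear ODE: the leading
    # Euler truncation-error coefficient. A trained net would approximate it.
    return 0.5 * LAM**2 * y


def euler_step(t, y, dt):
    return y + dt * f(t, y)


def hyper_euler_step(t, y, dt, order=1):
    # Base step plus dt**(order + 1) * correction, as in HyperEuler.step.
    return euler_step(t, y, dt) + dt ** (order + 1) * hypernet(t, y)


def solve(step, y0, t1, n):
    # Fixed-step integration from t = 0 to t = t1 in n steps.
    dt = t1 / n
    y = y0
    for k in range(n):
        y = step(k * dt, y, dt)
    return y


y0 = jnp.array(1.0)
exact = jnp.exp(LAM * 1.0)
err_euler = abs(solve(euler_step, y0, 1.0, 20) - exact)
err_hyper = abs(solve(hyper_euler_step, y0, 1.0, 20) - exact)
```

With 20 steps on [0, 1], the corrected solver's error is more than an order of magnitude smaller than plain Euler's, because the correction cancels the O(dt²) term of the local error.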

I made a script that tries to replicate this torchdyn tutorial:

https://colab.research.google.com/drive/1c9_AM-5NiLox1Do-DpiVJ-KXQKZDngvu?usp=sharing
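As I understand the paper, the core training idea is residual fitting: g_ω is regressed onto (x_{k+1} - x_k - ε·ψ) / ε^{p+1}, where x_{k+1} comes from an accurate reference solution. Below is a minimal plain-JAX sketch of that idea, not the code from the Colab or torchdyn. It assumes a toy linear ODE whose exact solution stands in for the reference solver, and a hypothetical two-parameter affine `hypernet` in place of a neural network.

```python
import jax
import jax.numpy as jnp

LAM = -1.0  # toy linear ODE dy/dt = LAM * y (illustrative assumption)
DT = 0.1    # fixed step size the hypersolver is trained for
ORDER = 1   # order p of the base (Euler) solver
LR = 0.1    # gradient-descent learning rate


def f(t, y):
    return LAM * y


def hypernet(params, t, y):
    # Hypothetical stand-in for a neural network: an affine map in y.
    return params["w"] * y + params["b"]


# Training states and "true" next states from the exact solution,
# standing in for a high-accuracy reference solve.
ys = jnp.linspace(0.1, 2.0, 50)
ys_next = ys * jnp.exp(LAM * DT)

# Residual targets: what the Euler step misses, scaled by DT**-(p + 1).
residuals = (ys_next - ys - DT * f(0.0, ys)) / DT ** (ORDER + 1)


def loss(params):
    pred = hypernet(params, 0.0, ys)
    return jnp.mean((pred - residuals) ** 2)


@jax.jit
def sgd_step(params):
    grads = jax.grad(loss)(params)
    return jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)


params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
for _ in range(2000):
    params = sgd_step(params)

# For this ODE the ideal correction is (exp(LAM*DT) - 1 - LAM*DT) / DT**2 * y.
w_ideal = (jnp.exp(LAM * DT) - 1 - LAM * DT) / DT**2
```

With a real network and a nonlinear vector field, the residual targets would instead come from a tight-tolerance reference solve, for example one of diffrax's adaptive solvers such as `diffrax.Tsit5`.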

anh-tong, Jun 16 '22 04:06

Hmm, I'd need to think about the cleanest abstraction (it feels a bit weird to me to pass a term into a solver at init time), but something like your approach makes sense.

Perhaps also make the base solver (here Euler) an input, so that there's a single Hypersolver class rather than separate HyperEuler etc. classes. For this, AbstractWrappedSolver may help.

(Also we could probably forward the error estimate etc through safely?)
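To illustrate the single-wrapper idea and the error-estimate forwarding without relying on diffrax internals, here is a plain-JAX sketch. It is not diffrax's `AbstractWrappedSolver` API: `StepResult`, `euler_step`, `heun_step` and `make_hypersolver` are hypothetical names. Forwarding the base solver's error estimate unchanged is one possible choice; whether it remains meaningful for the corrected step is the open question raised above.

```python
from typing import Callable, NamedTuple, Optional

import jax
import jax.numpy as jnp


class StepResult(NamedTuple):
    y1: jax.Array
    # Error estimate of the base solver; None if it does not provide one.
    y_error: Optional[jax.Array]


def euler_step(f, t0, t1, y0):
    # Order 1, no embedded error estimate.
    return StepResult(y0 + (t1 - t0) * f(t0, y0), None)


def heun_step(f, t0, t1, y0):
    # Order 2, with an embedded error estimate (Heun minus Euler).
    dt = t1 - t0
    k1 = f(t0, y0)
    k2 = f(t1, y0 + dt * k1)
    y1 = y0 + 0.5 * dt * (k1 + k2)
    return StepResult(y1, y1 - (y0 + dt * k1))


def make_hypersolver(
    base_step: Callable[..., StepResult], order: int, hypernet: Callable
) -> Callable[..., StepResult]:
    """Wrap any base step with the correction dt**(order + 1) * hypernet(t0, y0)."""

    def step(f, t0, t1, y0):
        base = base_step(f, t0, t1, y0)
        dt = t1 - t0
        y1 = base.y1 + dt ** (order + 1) * hypernet(t0, y0)
        # Forward the base solver's error estimate unchanged.
        return StepResult(y1, base.y_error)

    return step


# Usage: one wrapper, two different base solvers.
def f(t, y):
    return -y


def zero_net(t, y):
    return jnp.zeros_like(y)


def const_net(t, y):
    return jnp.ones_like(y)


y0 = jnp.array([1.0, 2.0])
hyper_heun = make_hypersolver(heun_step, order=2, hypernet=zero_net)
hyper_euler = make_hypersolver(euler_step, order=1, hypernet=const_net)

out_heun = hyper_heun(f, 0.0, 0.1, y0)
out_euler = hyper_euler(f, 0.0, 0.1, y0)
```

Because the base solver and its order are arguments, adding another base method only means writing its step function. The wrapper itself does not change.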

I'd be happy to accept a PR on this, as either something in the base library or as one of the examples.

patrick-kidger, Jun 16 '22 19:06

Got it! AbstractWrappedSolver is the way to go. I will gladly make a PR on this.

anh-tong, Jun 17 '22 07:06