Implementation of multirate methods in diffrax

Open · BenjaminDAnjou opened this issue 6 months ago · 2 comments

Has there been any discussion of implementing multirate methods in diffrax? See Sec. 2.4.3 of the linked article, in particular Fig. 4.

By this, I mean methods tailored to the case where an ODETerm has components with an explicit separation of timescales, so that different step sizes can be used internally for the separated (slow and fast) components.
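To make the idea concrete, here is a minimal sketch of the simplest possible two-rate scheme, assuming forward Euler for both components. The slow component takes one step of size `H` per macro step, while the fast component takes `m` substeps of size `H / m` with the slow state frozen at its start-of-step value. The names `multirate_euler_step`, `integrate`, `slow_rhs` and `fast_rhs` are hypothetical, and this is not necessarily the scheme from the linked article; practical multirate methods use higher-order stages and better slow/fast coupling.

```python
from typing import Callable

# Assumed right-hand-side signature for this sketch: f(t, y_slow, y_fast) -> float
RHS = Callable[[float, float, float], float]


def multirate_euler_step(f_slow: RHS, f_fast: RHS, t: float, y_slow: float,
                         y_fast: float, H: float, m: int) -> tuple[float, float]:
    """One macro step of size H of a hypothetical two-rate forward Euler scheme."""
    # Slow component: a single forward Euler step over the whole macro step.
    y_slow_new = y_slow + H * f_slow(t, y_slow, y_fast)
    # Fast component: m forward Euler substeps of size h = H / m, with the
    # slow state held frozen at its value at the start of the macro step.
    h = H / m
    for k in range(m):
        y_fast = y_fast + h * f_fast(t + k * h, y_slow, y_fast)
    return y_slow_new, y_fast


def integrate(f_slow: RHS, f_fast: RHS, y_slow: float, y_fast: float,
              t0: float, t1: float, n_macro: int, m: int) -> tuple[float, float]:
    """Take n_macro equal macro steps from t0 to t1."""
    H = (t1 - t0) / n_macro
    for n in range(n_macro):
        y_slow, y_fast = multirate_euler_step(f_slow, f_fast, t0 + n * H,
                                              y_slow, y_fast, H, m)
    return y_slow, y_fast


def slow_rhs(t: float, ys: float, yf: float) -> float:
    # Slow decay: y_s' = -y_s
    return -ys


def fast_rhs(t: float, ys: float, yf: float) -> float:
    # Fast relaxation of y_f towards y_s: y_f' = -50 (y_f - y_s)
    return -50.0 * (yf - ys)


y_slow, y_fast = integrate(slow_rhs, fast_rhs, 1.0, 1.0, 0.0, 1.0, n_macro=10, m=20)
print(y_slow, y_fast)
```

Here the macro step `H = 0.1` alone would make forward Euler unstable for the fast term (`50 * 0.1 = 5 > 2`), while the fast substeps of size `0.005` are stable. Freezing the slow state makes the fast component lag by roughly one macro step, which is the kind of coupling error that higher-order multirate schemes aim to reduce.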

My understanding is that IMEX methods partly address this, but they wash out the fast component in favor of an accurate slow component (adiabatic elimination). I am thinking of multirate methods, which accurately capture both the fast and the slow components.
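One way to see this washing out, assuming the fast part is handled by the implicit half of an IMEX scheme: take backward (implicit) Euler as the simplest stand-in and apply it to a purely oscillatory fast mode y' = iωy. Its amplification factor is R = 1 / (1 - iωh), so |R| = 1 / sqrt(1 + (ωh)^2), whereas the exact solution keeps |exp(iωh)| = 1. With a step sized for the slow dynamics (ωh ≫ 1), the fast oscillation is strongly damped rather than resolved. The helper name below is hypothetical.

```python
def backward_euler_amplification(lam: complex, h: float) -> complex:
    """Factor R such that y_{n+1} = R * y_n for backward Euler on y' = lam * y."""
    # Backward Euler: y_{n+1} = y_n + h * lam * y_{n+1}  =>  y_{n+1} = y_n / (1 - h * lam)
    return 1.0 / (1.0 - h * lam)


omega = 100.0  # angular frequency of the fast oscillation
# A step that resolves the fast mode vs. a step sized for the slow dynamics.
for h in (0.001, 0.1):
    R = backward_euler_amplification(1j * omega, h)
    print(f"h={h}: |R| = {abs(R):.4f} per step (exact solution: 1)")
```

With `h = 0.1` the amplitude of the fast mode shrinks by roughly a factor of 10 on every step, while with `h = 0.001` it is nearly preserved.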

If my understanding above is incorrect, feel free to correct me.

I think multirate methods could be a valuable tool in a wide array of physics problems. I would like to know how other people feel about this.

BenjaminDAnjou, Jul 02 '25 11:07

Hi there! This can be done pretty straightforwardly with a custom solver. For example, supposing your slow dynamics are independent of your fast dynamics, it would look something like this (untested):

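```python
from typing import Any

import diffrax


class TwoRateSolver(diffrax.AbstractSolver):
    slow_solver: diffrax.AbstractSolver
    # Keyword arguments for the inner `diffrax.diffeqsolve` on the fast term: must include
    # `solver`, `dt0` and `saveat=diffrax.SaveAt(t1=True, solver_state=True)`.
    fast_diffeqsolve: dict[str, Any]

    term_structure = (diffrax.AbstractTerm, diffrax.AbstractTerm)
    interpolation_cls = diffrax.LocalLinearInterpolation  # Or something better than this

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        slow_term, fast_term = terms
        slow_y0, fast_y0 = y0
        slow_state, fast_state = solver_state

        # Slow component: a single step of the slow solver over [t0, t1].
        slow_y1, _, _, new_slow_state, slow_result = self.slow_solver.step(
            slow_term, t0, t1, slow_y0, args, slow_state, made_jump
        )
        # Fast component: a full inner solve over [t0, t1], taking smaller steps.
        fast_sol = diffrax.diffeqsolve(
            fast_term, t0=t0, t1=t1, y0=fast_y0, args=args, solver_state=fast_state,
            made_jump=made_jump, **self.fast_diffeqsolve
        )
        [fast_y1] = fast_sol.ys  # SaveAt(t1=True) gives a leading axis of length 1
        new_fast_state = fast_sol.solver_state
        fast_result = fast_sol.result

        y1 = (slow_y1, fast_y1)
        dense_info = dict(y0=y0, y1=y1)
        new_solver_state = (new_slow_state, new_fast_state)
        # Report the slow solver's failure if there is one, else the fast solve's result.
        result = diffrax.RESULTS.where(
            slow_result == diffrax.RESULTS.successful, fast_result, slow_result
        )
        return y1, None, dense_info, new_solver_state, result

    ...  # fill in the other methods (e.g. `init` and `func`)
```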

I could see this being something we support directly in Diffrax, if there is enough interest. :)

patrick-kidger, Jul 06 '25 09:07

Thanks for the suggestion. I have a bit of a learning curve ahead before I can get this working, but I'll get started and try to put together an MWE.

In the meantime, I'd like to point out that this is at least partly related, conceptually, to the issues discussed in #462 and #595. I thought I'd make that connection here in case it helps make future implementations more streamlined and/or general.

BenjaminDAnjou, Sep 09 '25 23:09