Implementation of multirate methods in diffrax
Has there been any discussion of implementing multirate methods in diffrax? See Sec. 2.4.3 and Fig. 4 of the linked article.
By this, I mean methods tailored to the case where an ODETerm has components with an explicit separation of timescales, so that different internal time steps can be used for the separated components.
My understanding is that IMEX methods partly address this, but they wash out the fast component in favor of an accurate slow component (adiabatic elimination). I am thinking of multirate methods that accurately capture both the fast and slow components.
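As a rough illustration of the multirate idea (not of any diffrax API), here is a plain-Python sketch of a two-rate forward Euler scheme. The helper `multirate_euler` and the test problem are hypothetical: the slow variable decays as y_s' = -y_s and the fast variable relaxes toward it as y_f' = -lam * (y_f - y_s) with lam = 100. Each macro-step advances the slow variable once with step H (a "slowest first" ordering), then the fast variable takes m micro-steps of size H/m, reading the slow value by linear interpolation between its old and new values. With m = 1 the scheme reduces to ordinary single-rate forward Euler.

```python
import math


def multirate_euler(f_slow, f_fast, ys0, yf0, t_end, H, m):
    """Two-rate forward Euler (hypothetical helper, not part of diffrax).

    Each macro-step advances the slow variable once with step H, then advances
    the fast variable with m micro-steps of size h = H / m, reading the slow
    variable by linear interpolation between its old and new values.
    """
    h = H / m
    t, ys, yf = 0.0, ys0, yf0
    for _ in range(round(t_end / H)):
        # Slow macro-step first, using the current fast value.
        ys_new = ys + H * f_slow(t, ys, yf)
        for k in range(m):
            # Slow value linearly interpolated to the start of micro-step k.
            ys_k = ys + (k / m) * (ys_new - ys)
            yf = yf + h * f_fast(t + k * h, ys_k, yf)
        t, ys = t + H, ys_new
    return ys, yf


# Test problem: slow decay y_s' = -y_s, fast relaxation y_f' = -lam * (y_f - y_s).
lam = 100.0
f_slow = lambda t, ys, yf: -ys
f_fast = lambda t, ys, yf: -lam * (yf - ys)

# Exact solution at t = 1 for y_s(0) = 1, y_f(0) = 0.
ys_exact = math.exp(-1.0)
yf_exact = lam / (lam - 1.0) * (math.exp(-1.0) - math.exp(-lam))

ys_mr, yf_mr = multirate_euler(f_slow, f_fast, 1.0, 0.0, 1.0, H=0.05, m=10)
ys_sr, yf_sr = multirate_euler(f_slow, f_fast, 1.0, 0.0, 1.0, H=0.05, m=1)
print(f"multirate:   y_s={ys_mr:.4f}, y_f={yf_mr:.4f}")
print(f"single-rate: y_s={ys_sr:.4f}, y_f={yf_sr:.3e}")
print(f"exact:       y_s={ys_exact:.4f}, y_f={yf_exact:.4f}")
```

With H = 0.05 the single-rate run (m = 1) is unstable, because H * lam = 5 lies outside forward Euler's stability limit of 2 for the fast mode. The multirate run resolves the fast relaxation with micro-steps of 0.005 and stays close to the exact solution, while the slow variable still only takes 20 steps.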
If my understanding above is incorrect, feel free to educate me.
I think multirate methods could be a valuable tool in a wide array of physics problems. I would like to know how other people feel about this.
Hi there! This is something that can be done pretty straightforwardly with a custom solver. For example, supposing your slow dynamics are independent of your fast dynamics, then (untested) it would look something like this:
```python
from typing import Any, ClassVar

import diffrax


class TwoRateSolver(diffrax.AbstractSolver):
    slow_solver: diffrax.AbstractSolver
    # Keyword arguments for the inner diffeqsolve, e.g. solver, dt0 and
    # saveat=diffrax.SaveAt(t1=True, solver_state=True).
    fast_diffeqsolve: dict[str, Any]
    term_structure: ClassVar = (diffrax.AbstractTerm, diffrax.AbstractTerm)
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation  # Or something better than this

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        slow_term, fast_term = terms
        slow_y0, fast_y0 = y0
        slow_state, fast_state = solver_state
        # One macro-step of the slow component.
        slow_y1, _, _, new_slow_state, slow_result = self.slow_solver.step(
            slow_term, t0, t1, slow_y0, args, slow_state, made_jump
        )
        # Many micro-steps of the fast component over the same interval [t0, t1].
        fast_sol = diffrax.diffeqsolve(
            fast_term, t0=t0, t1=t1, y0=fast_y0, args=args, solver_state=fast_state,
            made_jump=made_jump, **self.fast_diffeqsolve
        )
        [fast_y1] = fast_sol.ys
        new_fast_state = fast_sol.solver_state
        y1 = (slow_y1, fast_y1)
        dense_info = dict(y0=y0, y1=y1)
        new_solver_state = (new_slow_state, new_fast_state)
        # Keep the slow solver's error if it failed, else report the fast solve's result.
        result = diffrax.RESULTS.where(
            slow_result == diffrax.RESULTS.successful, fast_sol.result, slow_result
        )
        return y1, None, dense_info, new_solver_state, result
    ...  # fill in the other methods, e.g. init and func
```
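To see what such a step does numerically without diffrax, here is a plain-Python analogue for the decoupled case. `rk4_step` and `two_rate_step` are hypothetical helpers: a single RK4 step stands in for `slow_solver.step`, and a loop of RK4 sub-steps stands in for the inner `diffeqsolve` over the same interval [t0, t1]. The test problem (decay rates 1 and 50) is made up for illustration.

```python
import math


def rk4_step(f, t, y, dt):
    """One classical fourth-order Runge-Kutta step for a scalar ODE y' = f(t, y)."""
    k1 = f(t, y)
    k2 = f(t + dt / 2, y + dt / 2 * k1)
    k3 = f(t + dt / 2, y + dt / 2 * k2)
    k4 = f(t + dt, y + dt * k3)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def two_rate_step(f_slow, f_fast, t0, t1, y0, n_fast):
    """Plain-Python analogue of a two-rate step for decoupled components.

    The slow component takes a single RK4 step over [t0, t1] (standing in for
    slow_solver.step); the fast component is integrated over the same interval
    with n_fast RK4 sub-steps (standing in for the inner diffeqsolve).
    """
    slow_y0, fast_y0 = y0
    slow_y1 = rk4_step(f_slow, t0, slow_y0, t1 - t0)
    dt = (t1 - t0) / n_fast
    fast_y1 = fast_y0
    for k in range(n_fast):
        fast_y1 = rk4_step(f_fast, t0 + k * dt, fast_y1, dt)
    return slow_y1, fast_y1


# Decoupled test problem: y_s' = -y_s (slow), y_f' = -50 y_f (fast).
f_slow = lambda t, y: -y
f_fast = lambda t, y: -50.0 * y

slow_20, fast_20 = two_rate_step(f_slow, f_fast, 0.0, 0.1, (1.0, 1.0), n_fast=20)
slow_1, fast_1 = two_rate_step(f_slow, f_fast, 0.0, 0.1, (1.0, 1.0), n_fast=1)
print(slow_20, math.exp(-0.1))  # slow: one macro-step is already accurate
print(fast_20, math.exp(-5.0))  # fast: 20 sub-steps track the exact decay
print(fast_1)                   # fast with no sub-stepping: far off
```

With a single step of 0.1, the fast mode has dt * 50 = 5, outside RK4's real-axis stability interval (roughly [-2.8, 0]), so the n_fast = 1 result is badly wrong. Because neither component reads the other here, the two updates can be done in either order; once the fast dynamics depend on the slow state (or vice versa), the inner solve needs some representation of the other component over [t0, t1], such as an interpolant.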
I could see this being something we support directly in Diffrax, if there is enough interest. :)
Thanks for the suggestion. I have a bit of a learning curve ahead before I can get this working, but I'll make a start and try to put together an MWE.
In the meantime, I'd like to point out that this is at least partly related, conceptually, to the issues discussed in #462 and #595. I thought I'd make that connection here in case it helps make future implementations more streamlined and/or general.