Out of memory: How to optimize memory usage in the Diffrax framework?

Open timnotavailable opened this issue 4 months ago • 4 comments

Hello, I wrote a simulator for an ODE system (with at least 256x256 ODEs in the system), using `max_steps=1000`, the `Tsit5` solver and `adjoint=diffrax.RecursiveCheckpointAdjoint()`. However, I get the following error: `Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate xxxx bytes.`

I'm new to neural ODEs and would appreciate some advice on how to reduce memory usage. (From previous issues, maybe I should reduce `max_steps`.) What about the solver? Should I use a simpler solver, such as a 2nd- or 3rd-order Runge-Kutta solver? Any other suggestions for optimizing GPU usage?

Does the adjoint method influence memory usage? I saw in https://docs.kidger.site/diffrax/api/adjoints/ that one can use `max_steps`, or `checkpoints` in the `RecursiveCheckpointAdjoint` class, to control memory usage. Would other adjoint methods save memory? A past post suggested using `RecursiveCheckpointAdjoint` because it has been optimized to O(log n); would any other method be more memory efficient?

If the ODE's raster time step is known, is `ConstantStepSize` better than an adaptive step size controller?

Thanks for your brilliant library!

timnotavailable, Sep 08 '25 11:09

Hey there! So the main option for controlling memory usage is `RecursiveCheckpointAdjoint(checkpoints=...)`. Smaller values decrease memory usage at the cost of more computation.

If this is not provided explicitly (the default), then it will be set automatically from `diffeqsolve(..., max_steps=...)`, which means this value can also affect memory usage.

I think the choice of solver and step size controller affects memory usage only by a small amount, as we only need enough memory to backpropagate a single step at a time.
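To see why the solver is a secondary factor, here is a rough back-of-the-envelope model. It is an assumption for illustration only, not how Diffrax actually allocates memory: each checkpoint is taken to store about one copy of the state, and backpropagating a single Runge-Kutta step is taken to need about one state copy per stage. The stage counts are those of the standard Heun, Bosh3 and Tsit5 tableaux; the helper functions and the value `checkpoints=10` are hypothetical.

```python
import math

# Stages per step of the explicit Runge-Kutta tableaux behind
# diffrax.Heun (2nd order), diffrax.Bosh3 (3rd order) and diffrax.Tsit5 (5th order).
STAGES = {"Heun": 2, "Bosh3": 4, "Tsit5": 7}


def state_bytes(shape, bytes_per_element=4):
    """Bytes in one copy of the ODE state (float32 by default)."""
    return math.prod(shape) * bytes_per_element


def rough_estimate_bytes(shape, checkpoints, solver):
    """Hypothetical crude model, NOT Diffrax's real allocation:
    one state copy per checkpoint, plus one state copy per RK stage for the
    single step that is being backpropagated at any one time."""
    s = state_bytes(shape)
    return checkpoints * s + STAGES[solver] * s


for name in STAGES:
    print(name, rough_estimate_bytes((256, 256), checkpoints=10, solver=name))
```

Under these assumptions a 256x256 float32 state is 262,144 bytes, so swapping Tsit5 for Bosh3 saves three state copies. That saving is a fixed constant that does not grow with the number of steps, whereas the checkpoint term is the one that `checkpoints` and `max_steps` control.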

Make sure to JIT the whole computation - even outside of the grad.

I hope that helps :)

patrick-kidger, Sep 08 '25 22:09

Hey! Thanks for the answer, it really helps. If I understand correctly, you mean that `max_steps` or `checkpoints` can be used to balance memory usage against computation. After reading your doctoral thesis and reviewing related literature, I still have a few small questions about adjoint methods and reversible solvers:

  1. When using a reversible solver, if the dynamics are known (referring to Section 2.2.2 of your thesis) and not stiff (or the step size is set quite small), my understanding is that an adjoint method is no longer needed. Do I still have to use `RecursiveCheckpointAdjoint(checkpoints=...)`, or perhaps another adjoint method?
  2. If I only need the forward pass and not the backpropagated gradients, then I should set `checkpoints=0`, right?
  3. Section 5.3.1.1 of your thesis (General principles on Adaptive versus fixed step solvers) says: "This implies a variable computational cost, typically increasing over the course of training as model complexity increase". Consider an ODE with many steps, whose dynamics are known and whose raster time is $\delta t_0$. Does this sentence suggest using a constant step size equal to $\delta t_0$?

Additionally, regarding reversible ODE solvers: I saw that there has been an attempt to integrate reversible solvers into the Diffrax framework (https://github.com/sammccallum/reversible-solvers). Is this available now?

timnotavailable, Sep 09 '25 09:09

  1. Reversible solvers are currently not implemented in Diffrax (though, as you note, we have a PR to change that!)

  2. If you don't backpropagate, then any choice of adjoint will work equally well. In particular, checkpoints are only saved if you backpropagate, so no extra memory is used.

  3. Adaptive step sizing is typically the better choice, as it will usually give a more accurate solution to the ODE. But a small constant step size is reasonable too.

patrick-kidger, Sep 10 '25 19:09

Got it! Everything is clear now, thanks!

timnotavailable, Sep 10 '25 21:09