
Recommendation for real-time publishing of vector field

Open jpbrodrick89 opened this issue 7 months ago • 2 comments

As far as I understand, there are two main ways of inspecting the progress of a diffeqsolve:

  • progress_meter: this is called every timestep (JAX conditionals can restrict it to every n timesteps instead), but it takes the solver_state rather than the vector field as an argument. So unless a custom solver is defined that saves a copy of the vector field in its internal state, we are limited in the information we can save.
  • saveat: here diffrax stores the relevant data in memory and returns a solution variable containing all of it after the solve completes. What I would like to do instead is push (and potentially post-process) the data at each saveat timestep to a file or an external server (such as MLflow) immediately, or potentially asynchronously, so that I can monitor how the simulation is progressing in real time.

Do you have any recommendations on how to achieve this with the current API? Or a laundry list of the PRs that would be required to make something like this possible?

Thanks again! 🙏🏻

jpbrodrick89 · Jun 03 '25 13:06

In the general case, this is basically live saving of data inside a JAX while loop. I've never tried it, but I assume the go-to would be an io_callback (https://docs.jax.dev/en/latest/_autosummary/jax.experimental.io_callback.html#jax.experimental.io_callback). For diffrax, it might be possible to just put these callbacks inside the SaveAt (never tried, but that is what I would try first).

lockwo · Jun 03 '25 14:06
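The general pattern from the comment above (live saving from inside a JAX while loop via io_callback) can be sketched as follows. This uses plain jax.lax.while_loop with a hand-written explicit Euler step rather than diffrax, and the in-memory LOG list and the _push helper are hypothetical stand-ins for a file writer or a tracking-server client. ordered=True asks JAX to run the host calls in program order, and jax.effects_barrier() waits for all pending side effects before LOG is read.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

LOG = []  # hypothetical sink; in practice a file or a tracking server


def _push(step, t, y):
    # Runs on the host each time the loop body executes; arguments arrive as arrays.
    LOG.append((int(step), float(t), float(y)))


@jax.jit
def euler(y0, dt, n_steps):
    def cond(state):
        step, _, _ = state
        return step < n_steps

    def body(state):
        step, t, y = state
        y = y + dt * (-y)  # explicit Euler step for dy/dt = -y
        t = t + dt
        step = step + 1
        # Side effect: push the new state out of the compiled loop immediately.
        io_callback(_push, None, step, t, y, ordered=True)
        return step, t, y

    init = (jnp.int32(0), jnp.float32(0.0), jnp.asarray(y0, dtype=jnp.float32))
    return jax.lax.while_loop(cond, body, init)


final_step, final_t, final_y = euler(1.0, 0.1, 5)
jax.effects_barrier()  # make sure every callback has run before reading LOG
print(LOG)
```

Because the callback fires inside the loop, each entry is available on the host as soon as that iteration runs, rather than after the loop returns.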

I think you could probably do this with SaveAt(fn=...), where fn wraps a jax.pure_callback that performs the save.

patrick-kidger · Jun 03 '25 15:06