Recommendation for real-time publishing of the vector field
As far as I understand, there are two main ways of inspecting the progress of a `diffeqsolve`:
- `progress_meter`: this is called every timestep (with the option of using JAX conditionals to act only every n timesteps), but it receives the `solver_state` rather than the vector field as an argument. Unless a custom solver is defined that keeps a copy of the vector field in its internal state, we are limited in the information we can save.
- `saveat`: here diffrax stores the relevant data in memory and returns a `solution` variable containing all of it after the solve completes.

What I would like to do instead is push (and potentially postprocess) data at each `saveat` timestep to a file or an external server (such as mlflow) immediately, or potentially asynchronously, so that I can monitor how the simulation is progressing in real time.
Do you have any recommendations on how to achieve this with the current API? Or a laundry list of the PRs that would be required to make something like this possible?
Thanks again! 🙏🏻
In the general case, this is basically live saving of data inside a JAX while loop. I've never tried it, but I assume the go-to would be an `io_callback` (https://docs.jax.dev/en/latest/_autosummary/jax.experimental.io_callback.html#jax.experimental.io_callback). For diffrax, it might be possible to just put these callbacks inside the `SaveAt` (again untested, but that is what I would try first).
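A minimal sketch of that general case in plain JAX, without diffrax: a `jax.lax.while_loop` whose body hands each iterate to the host through `io_callback`. The halving step is a placeholder for a solver step, and `publish` and `records` are hypothetical stand-ins for writing to a file or a tracking server. `jax.effects_barrier()` waits for the host-side callbacks to finish before `records` is read. Since the callback is unordered, nothing here relies on the order in which records arrive.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

records = []  # hypothetical host-side sink standing in for a file or tracking server


def publish(step, x):
    # Runs on the host with NumPy values; a real version might write to a file
    # or call e.g. mlflow.log_metric("x", float(x), step=int(step)) instead.
    records.append((int(step), float(x)))


@jax.jit
def run(x0):
    def cond(carry):
        step, _ = carry
        return step < 5

    def body(carry):
        step, x = carry
        x = 0.5 * x  # placeholder for one solver step
        # Side-effecting callback; result_shape_dtypes=None because publish returns nothing.
        io_callback(publish, None, step, x)
        return step + 1, x

    return jax.lax.while_loop(cond, body, (jnp.int32(0), x0))


final_step, final_x = run(jnp.float32(8.0))
jax.effects_barrier()  # wait until every host-side callback has run
```

If the backend needs records strictly in sequence, `io_callback` also accepts `ordered=True`, which asks JAX to keep sequential calls in order.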
I think you could probably do this with `SaveAt(fn=...)`, where `fn` wraps a `jax.pure_callback` that performs the save.