
Recommendation for saving auxiliary data

Open jpbrodrick89 opened this issue 7 months ago • 2 comments

We often compute auxiliary variables at each timestep within vector_field that we would like to save as part of processing, for future reference and sanity checking. We would like to avoid manually recalculating these in postprocessing, so that we can debug with confidence. What is the best-practice way of achieving this? One option is a custom stepper in which the auxiliary data is not "stepped"/incremented but instead updated/replaced, but that doesn't feel ideal. Ideally, something like the has_aux option of the standard jax.vjp could be supported. Any other suggestions? Thanks 🙂

jpbrodrick89 avatar Jun 03 '25 13:06 jpbrodrick89

While it might involve some recomputation, this sounds like it could be done with SaveAt(fn=...).

lockwo avatar Jun 03 '25 14:06 lockwo

Yup, I think @lockwo's approach is probably best.

In general we don't support saving aux output from a vector field, because it's not clear how this should be defined: for example, when multiple vector field calls are made inside the implicit solve of a stiff solver, or when time progresses non-monotonically due to step rejection.

This means that aux output is best thought of as a function of (t, y, args) and handled via SaveAt(fn=...), rather than as an output of the vector field directly, even if that just involves calling the vector field again.

patrick-kidger avatar Jun 03 '25 15:06 patrick-kidger