Recommendation for saving auxiliary data
We often compute auxiliary variables at each timestep inside vector_field that we would like to save for future reference and sanity checking. We would like to avoid manually recalculating these during postprocessing, so that we can debug with confidence. What is the best-practice way of achieving this? One option is a custom stepper in which the auxiliary data is not "stepped"/incremented but updated/replaced instead, but that doesn't feel ideal. Ideally, something like the has_aux argument of the standard jax.vjp could be supported. Any other suggestions? Thanks 🙂
While it might involve some recomputation, this sounds like something SaveAt(fn=...) could handle.
Yup, I think @lockwo's approach is probably best.
We don't support saving aux output from a vector field because it's not clear how this should be defined in general. For example, the vector field may be called multiple times per step inside the implicit solve of a stiff solver, or time may progress nonmonotonically due to step rejection.
This means that aux output is best thought of as a function of (t, y, args) and handled via SaveAt(fn=...), even if that just involves calling the vector field -- rather than as an output of the vector field directly.