diffrax
Is it possible to obtain the steps of the trajectories?
Hello,
I was wondering whether it is possible to use some kind of callback function with diffeqsolve, so that I get not only the final step of the trajectory over the specified time interval but also all the intermediate steps. If so, could you give me a quick example of how to implement it myself?
Call sol = diffeqsolve(..., saveat=SaveAt(steps=True)).
Then sol.ys will be an array of the solution at every step the differential equation solver took, and sol.ts will hold the corresponding times. (Because JAX works with statically sized arrays, these arrays have length max_steps and are padded at the end with infs.)