Saving metrics during solve
Hi, thanks again for the library.
We have a numerical integration that we perform using `lax.scan`, and we'd like to port it over to diffrax. I wanted to get your thoughts on the shortest path to this implementation.
From https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html
```python
def scan(f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)
```
We rely quite a bit on the implicit `append` and `stack` calls that happen inside `lax.scan`. I think the core of my question is: is there a way to avoid the pre-allocation you suggest in https://github.com/patrick-kidger/diffrax/issues/60#issuecomment-1034768386? My guess is the answer is no, but I figured I'd see if you had any ideas.
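For concreteness, our current usage looks roughly like this (the `step` body is a stand-in, not our actual PDE):

```python
import jax
import jax.numpy as jnp

def step(y, _):
    y = y * 0.99               # stand-in for one PDE timestep
    metric = jnp.mean(y ** 2)  # per-step reduced quantity
    return y, metric           # `metric` is implicitly stacked by scan

y0 = jnp.ones(1000)
y_final, metrics = jax.lax.scan(step, y0, None, length=100)
# metrics.shape == (100,) -- no manual pre-allocation needed
```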
TY in advance!
Do I understand that:
- As in #60, you're interested in logging metrics during the solve, and want to save something at every step?
- As you're currently using a `lax.scan`, you're probably using a fixed step size rather than an adaptive step size?
If this is the case then this is totally doable. Follow the approach in #60 to create a wrapper solver that stores your metrics. `metric_state` should be a buffer of the appropriate size (equal to the number of steps you'll be taking), which you update using "in-place updates" (`buffer = buffer.at[step].set(value)`). Now the analogous behaviour to a `lax.scan` in Diffrax is to pass `stepsize_controller=ConstantStepSize(compile_steps=True)`, and if you do this then you should find that your "in-place updates" really do get compiled down to efficient in-place updates!
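To illustrate just the buffer idiom in plain JAX (the wrapper-solver wiring from #60 is omitted, and `body` is a stand-in for the real solver step):

```python
import jax
import jax.numpy as jnp

num_steps = 100

def solve_with_metrics(y0):
    # Pre-allocate one slot per step, as described above.
    metric_state = jnp.zeros(num_steps)

    def body(step, carry):
        y, metric_state = carry
        y = y * 0.99  # stand-in for one solver step
        # "In-place update": compiled to a true in-place write under jit,
        # because the buffer is threaded through the loop carry.
        metric_state = metric_state.at[step].set(jnp.mean(y**2))
        return y, metric_state

    return jax.lax.fori_loop(0, num_steps, body, (y0, metric_state))

y_final, metrics = jax.jit(solve_with_metrics)(jnp.ones(1000))
```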
Can you give me some examples of the sort of thing you're trying to save, by the way? That this has come up twice now suggests that it may be worth creating an easy way to save user-specified metrics. (And in particular without needing to worry about this business of the efficiency of in-place updates.)
Yes to both!
Okay, sounds good. I was just hoping that I wouldn't have to preallocate `metric_state` because, well, I don't do it now and therefore it's more work :P
> Can you give me some examples of the sort of thing you're trying to save, by the way?
Yes, of course.
I am solving a plasma physics PDE where it's impractical, if not impossible (due to speed and storage), to store the full state at every timestep. So we calculate some moments, interpolations, averages, etc., and track those over time. See figs. 5 and 6 for examples of reduced versions of the full state over time.
This is definitely the case for other PDEs in plasma physics, and I assume the same happens in PDE solves in other domains.
So given the state `y` (at each timestep): at the moment, with `SaveAt(steps=True)`, we simply save `y`. Perhaps what could be done is to allow inserting a user-specified function `(t, y) -> anything` here that determines what to save, defaulting to

```python
def just_state(t: Scalar, y: Array):
    return y
```

Perhaps it could/should also be a function of `solver_state` etc.
That's exactly what I do right now in the scan (in a Haiku module, I know, I know, transform magic; I'm certainly thinking about Equinox):

```python
def one_step(y, current_params):
    y = evolve_PDE(y)                    # advance the PDE one timestep
    temp_storage = self.storage_step(y)  # reduce the full state to the saved metrics
    return y, temp_storage
```

`self.storage_step` is the postprocessing step.
Sounds good! I've adjusted the title of this issue to make this a feature request. (Which, if you ever feel like digging into the guts of Diffrax, I'd be happy to shepherd a PR on.)
I am in. Is this where we'll be working?
https://github.com/patrick-kidger/diffrax/blob/cec091c5e4cc4311f64ae3aa09a371db5fe766ee/diffrax/integrate.py#L246
Yep! It's this big block of code. The idea will be to change from `state.ys` -- which is where we save our output -- to `state.out`.

Now the bad news is that this part of the code is pretty complicated to read, as it needs to carefully make sure it gets good performance. Fortunately, the good news is that you shouldn't need to change any of that -- just carefully figure out what's what, and replace saving `ys` with saving `save_fn(ys)`.
I'd suggest that the saving-function should be specified as part of the `diffeqsolve(..., saveat=...)` argument, and that when called it should take `state` as an argument.

As a nice bonus, I've realised that this makes the `solver_state`, `controller_state`, and `made_jump` arguments to `SaveAt` superfluous: these can be obtained simply by using the appropriate metric-saving function.
Looking at this again, do I want to trace back the `ys=ys` line to be more of an `out=out` line?
```python
new_state = _State(
    y=y,
    tprev=tprev,
    tnext=tnext,
    made_jump=made_jump,
    solver_state=solver_state,
    controller_state=controller_state,
    result=result,
    num_steps=num_steps,
    num_accepted_steps=num_accepted_steps,
    num_rejected_steps=num_rejected_steps,
    saveat_ts_index=saveat_ts_index,
    ts=ts,
    ys=ys,
    save_index=save_index,
    dense_ts=dense_ts,
    dense_infos=dense_infos,
    dense_save_index=dense_save_index,
)
```
So given that the idea here is to save these summary statistics instead of `y`: we could always just use the `ys` buffer itself.

Then, at all the places where we currently do `ys.at[index].set(y)`, we instead do `ys.at[index].set(save(t, y, args))`, where `save` is a function passed to `diffeqsolve`, which defaults to `save = lambda t, y, args: y`.
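Spelled out, the proposal is just this (the second `save` is an illustrative reduction, not existing API):

```python
import jax.numpy as jnp

# Proposed default: identical to today's behaviour.
save = lambda t, y, args: y

# A user-supplied alternative, e.g. a single reduced quantity per step:
save = lambda t, y, args: jnp.mean(jnp.abs(y) ** 2)

# Inside the integration loop, every occurrence of
#     ys = ys.at[index].set(y)
# then becomes
#     ys = ys.at[index].set(save(t, y, args))
```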
Closing as this has been completed in the about-to-be-released v0.3.0.
For the original PDE use-case, you may find the new SubSaveAt functionality useful, in that it allows you to save the terminal value in its entirety, and just the statistics of the evolving solution. See the examples on that page.
I'd also welcome any feedback on the new nonlinear heat PDE example.
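For instance, something along these lines (a sketch of the new API; `norm_fn` is an illustrative reduction, not the one from this thread):

```python
import diffrax
import jax.numpy as jnp

# Save the full terminal state, plus one reduced scalar at every step.
norm_fn = lambda t, y, args: jnp.linalg.norm(y)

saveat = diffrax.SaveAt(
    subs=[
        diffrax.SubSaveAt(t1=True),                 # full state, end only
        diffrax.SubSaveAt(steps=True, fn=norm_fn),  # scalar per step
    ]
)

sol = diffrax.diffeqsolve(
    diffrax.ODETerm(lambda t, y, args: -y),
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=jnp.ones(4),
    saveat=saveat,
)
y_final, per_step_norms = sol.ys  # output structure mirrors the `subs` list
```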
Thanks for helping me get this through. `SubSaveAt` is a nice feature add too!
I think I'm on to a bug, but I don't have a minimal repro for this yet, so I thought I'd see if something comes to your mind right away.

I'm getting NaNs when trying to take a gradient of one of these reduced quantities. The gradient comes through fine when using `fn=None` with a `ts=`. I'm only using `jnp.interp` in my `fn`, so I'd like to say it should be diff'able.

Any thoughts? I'll try to get a minimal repro going in the meantime...
I think I have this traced down to storing `abs` and `angle` of complex values. This happens even if I take the gradient of a different quantity that never undergoes a complex transformation. So it's probably related to the complex support issue in https://github.com/patrick-kidger/diffrax/issues/96.
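For what it's worth, here's a tiny standalone illustration of how `abs` of a complex value can produce NaN gradients at zero (not my actual repro):

```python
import jax
import jax.numpy as jnp

# |z| = sqrt(re^2 + im^2); its backward pass divides by |z|,
# so the gradient is NaN wherever z == 0.
def f(x):
    z = x + 1j * x  # real input routed through a complex intermediate
    return jnp.abs(z)

print(jax.grad(f)(1.0))  # ~1.414: fine away from zero
print(jax.grad(f)(0.0))  # nan
```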