
Possible to get BacksolveAdjoint backward pass solution/stats?

Open jjyyxx opened this issue 2 years ago • 3 comments

Thanks for this excellent library!

When using the BacksolveAdjoint adjoint method, it's easy to get (and log) the forward-pass stats with

sol = diffrax.diffeqsolve(
    terms, solver, t_0, t_T, initial_step_size, init_y, args,
    saveat=saveat, stepsize_controller=stepsize_controller,
    adjoint=diffrax.BacksolveAdjoint()
)
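# sol.stats is a dict of forward-pass solver statistics (e.g. step counts)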
stats = sol.stats
# ...
(loss, stats), grads = jax.value_and_grad(loss_fn, has_aux=True)(policy_params, init_states)

I believe the backward pass also involves solving an ODE, but I did not figure out a way to get its stats. Could you suggest how to achieve this functionality, or is it a limitation of the JAX API?

jjyyxx avatar Jun 01 '22 03:06 jjyyxx

I'd be happy to offer this if you can determine an API for it. Because this happens inside a custom_vjp, as far as I can see JAX doesn't offer anywhere to return auxiliary information.

If you want a hacky version, then you could just modify the source code for BacksolveAdjoint and print/save the stats as a side effect using jax.experimental.host_callback.
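For instance, a generic sketch of that side-effect pattern (nothing diffrax-specific; f here is just a stand-in for the backward pass):

import jax
import jax.experimental.host_callback as hcb

def f(x):
    y = x * x
    # id_print prints its argument on the host as a side effect and
    # returns it unchanged, so it can be dropped into traced code.
    y = hcb.id_print(y, what="stats")
    return y

jax.jit(f)(3.0)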

patrick-kidger avatar Jun 01 '22 08:06 patrick-kidger

Thanks for your quick reply!

For the first way, it may relate to https://github.com/google/jax/issues/2796 and https://github.com/google/jax/pull/2574, for which there is currently no elegant solution. Passing a dummy stats argument, as proposed in https://github.com/google/jax/pull/2574, might work, but I'm not sure whether you'd consider it worth implementing.
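To make the dummy-argument idea concrete, here is a minimal self-contained sketch (all names are made up; f stands in for the solve): the bwd rule returns the stats in the cotangent slot of a dummy input, so differentiating with respect to that input carries them out of the custom_vjp.

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, dummy_stats):
    return jnp.sin(x)

def f_fwd(x, dummy_stats):
    return jnp.sin(x), x

def f_bwd(x, g):
    # Stand-in for real backward-pass stats; must match the structure
    # and dtypes of dummy_stats, since it is returned as its cotangent.
    stats = {"num_steps": jnp.array(17.0)}
    return g * jnp.cos(x), stats

f.defvjp(f_fwd, f_bwd)

dummy = {"num_steps": jnp.array(0.0)}
grad_x, stats = jax.grad(f, argnums=(0, 1))(jnp.array(1.0), dummy)
# stats is now {"num_steps": 17.0}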

For the hacky way, I implemented a draft as follows:

diff --git a/adjoint.py b/adjoint2.py
index 79dae8e..3658339 100644
--- a/adjoint.py
+++ b/adjoint2.py
@@ -10,6 +10,7 @@ from .misc import nondifferentiable_output, ω
 from .saveat import SaveAt
 from .term import AbstractTerm, AdjointTerm
 
+import jax.experimental.host_callback
 
 class AbstractAdjoint(eqx.Module):
     """Abstract base class for all adjoint methods."""
@@ -180,7 +181,7 @@ def _loop_backsolve_bwd(
 
     def _scan_fun(_state, _vals, first=False):
         _t1, _t0, _y0, _grad_y0 = _vals
-        _a0, _solver_state, _controller_state = _state
+        _a0, _solver_state, _controller_state, _stats = _state
         _a_y0, _a_diff_args0, _a_diff_term0 = _a0
         _a_y0 = (_a_y0**ω + _grad_y0**ω).ω
         _aug0 = (_y0, _a_y0, _a_diff_args0, _a_diff_term0)
@@ -205,10 +206,12 @@ def _loop_backsolve_bwd(
         _a1 = (_a_y1, _a_diff_args1, _a_diff_term1)
         _solver_state = _sol.solver_state
         _controller_state = _sol.controller_state
+        _sol.stats.pop("max_steps")
+        _stats = _sol.stats if _stats is None else jax.tree_map(lambda x, y: x + y, _stats, _sol.stats)
 
-        return (_a1, _solver_state, _controller_state), None
+        return (_a1, _solver_state, _controller_state, _stats), None
 
-    state = ((zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms), None, None)
+    state = ((zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms), None, None, None)
     del zeros_like_y, zeros_like_diff_args, zeros_like_diff_terms
 
     # We always start backpropagating from `ts[-1]`.
@@ -240,7 +243,7 @@ def _loop_backsolve_bwd(
             val = (ts[0], ts[1], ω(ys)[1].ω, ω(grad_ys)[1].ω)
             state, _ = _scan_fun(state, val, first=True)
 
-        aug1, _, _ = state
+        aug1, _, _, _stats = state
         a_y1, a_diff_args1, a_diff_terms1 = aug1
         a_y1 = (ω(a_y1) + ω(grad_ys)[0]).ω
 
@@ -261,8 +264,9 @@ def _loop_backsolve_bwd(
             val = (t0, ts[0], ω(ys)[0].ω, ω(grad_ys)[0].ω)
             state, _ = _scan_fun(state, val, first=True)
 
-        aug1, _, _ = state
+        aug1, _, _, _stats = state
         a_y1, a_diff_args1, a_diff_terms1 = aug1
+    jax.experimental.host_callback.id_print(_stats)
 
     return a_y1, a_diff_args1, a_diff_terms1

I'm unsure if I understand your code correctly. Could you help review it? The hacky solution luckily did not have an obvious negative impact on performance, but the biggest pain is that it's inconsistent with the overall code execution flow (e.g. logging to TensorBoard) and requires some more hacks with id_tap in my code outside jax.jit.
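For reference, the extra id_tap hack looks roughly like this (a sketch only; collected_stats and _tap_stats are my own names):

import jax.experimental.host_callback as hcb

collected_stats = []  # plain Python list, filled in on the host side

def _tap_stats(stats, transforms):
    # Runs on the host, outside jit, so it can feed e.g. a TensorBoard writer.
    collected_stats.append(stats)

# Then, inside _loop_backsolve_bwd, in place of id_print:
# _stats = hcb.id_tap(_tap_stats, _stats)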

jjyyxx avatar Jun 02 '22 03:06 jjyyxx

I probably wouldn't mutate the dict using pop, but other than that LGTM.

Something else that might make logging easier is using equinox.experimental.{get,set}_state instead of host_callback directly. This wraps host_callback to provide an interface for stateful operations; in this case, saving data and retrieving it at a later time.
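Roughly like this (a sketch only; this API is experimental and may change):

import equinox as eqx
import jax.numpy as jnp

index = eqx.experimental.StateIndex()

# Inside the backward pass, save the stats (here just a step count):
# eqx.experimental.set_state(index, num_steps)

# Later, in your own code, retrieve them; the like argument specifies
# the expected shape/dtype:
# num_steps = eqx.experimental.get_state(index, like=jnp.array(0))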

patrick-kidger avatar Jun 02 '22 09:06 patrick-kidger