Out of memory: how to reduce the memory cost of computing derivatives w.r.t. the input
I want to compute the derivatives of the output of a neural ODE w.r.t. its input, but I run out of memory. I think the memory-saving method `RecursiveCheckpointAdjoint` doesn't support forward-mode automatic differentiation (e.g. `jax.jvp` or `jax.jacfwd`). In my case, I want to compute the derivatives of `f` (the output of the neural ODE) w.r.t. `x` and `t` (the inputs), as below:
```python
import jax
import jax.numpy as jnp


def f_and_derivs_fast_vec(variables, apply_fn, xt, t1):
    """
    xt: (N, 2) with columns [x, t]
    returns f, f_x, f_t, f_xx, each (N, D)
    """
    xt = jnp.asarray(xt)

    def f_vec(z):  # z: (2,) -> (D,)
        f = apply_fn(variables, z[None, :], t1=t1)  # model returns (1, D)
        return f[0]  # (D,)

    ex = jnp.array([1.0, 0.0])  # d/dx
    et = jnp.array([0.0, 1.0])  # d/dt

    def one_point(z):
        f = f_vec(z)  # (D,)
        _, fx = jax.jvp(f_vec, (z,), (ex,))  # (D,)
        _, ft = jax.jvp(f_vec, (z,), (et,))  # (D,)

        def gx(y):
            return jax.jvp(f_vec, (y,), (ex,))[1]  # f_x(y)

        _, fxx = jax.jvp(gx, (z,), (ex,))  # (D,)
        return f, fx, ft, fxx

    f, fx, ft, fxx = jax.vmap(one_point)(xt)
    return f, fx, ft, fxx
```
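To sanity-check the nested `jax.jvp` pattern used in `one_point` without paying for an ODE solve, a cheap toy field can stand in for `f_vec`. The sketch below assumes a hypothetical `toy_f(z) = [sin(x)·t, x²]` with `z = [x, t]`, for which f_x, f_t and f_xx are known in closed form. It also reuses the primal output of the first `jax.jvp` as `f` instead of calling the function a separate time.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for f_vec: maps z = [x, t] (shape (2,)) to shape (2,).
def toy_f(z):
    x, t = z[0], z[1]
    return jnp.stack([jnp.sin(x) * t, x**2])


ex = jnp.array([1.0, 0.0])  # tangent for d/dx
et = jnp.array([0.0, 1.0])  # tangent for d/dt


def one_point(z):
    # The primal output of a jvp is f(z) itself, so no separate call is needed.
    f, fx = jax.jvp(toy_f, (z,), (ex,))
    _, ft = jax.jvp(toy_f, (z,), (et,))

    # f_xx: differentiate the x-directional derivative again along x.
    def gx(y):
        return jax.jvp(toy_f, (y,), (ex,))[1]

    _, fxx = jax.jvp(gx, (z,), (ex,))
    return f, fx, ft, fxx


xt = jnp.array([[0.5, 2.0], [1.0, 3.0]])  # (N, 2) rows of [x, t]
f, fx, ft, fxx = jax.vmap(one_point)(xt)
```

With the real model, each of these `jax.jvp` calls traces its own transformed copy of the ODE solve into the program, which is likely part of why the compiled module gets so large.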
I define my Neural ODE as below:
```python
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp


class Func(eqx.Module):
    out_scale: jax.Array
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.out_scale = jnp.array(1.0)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.swish,
            final_activation=jax.nn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.out_scale * self.mlp(y)


class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, t1, y0):
        y0 = jnp.asarray(y0).reshape(-1)  # (D,)
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=0.0,
            t1=t1,
            dt0=1e-3,
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=3e-3, atol=3e-6),
            saveat=diffrax.SaveAt(t1=True, dense=False),
            adjoint=diffrax.RecursiveCheckpointAdjoint(),
        )
        ys = jnp.asarray(solution.ys).reshape(-1)  # (D,)
        return ys
```
I define my model as below:
```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr


class PINN(nn.Module):
    n_nodes: int
    n_layers: int = 1
    node_data_size: int = 512  # Size of the data input to the NODE
    node_width: int = 64
    node_depth: int = 2

    def setup(self):
        self.hidden_layers = [
            nn.Dense(self.n_nodes, kernel_init=jax.nn.initializers.he_uniform())
            for _ in range(self.n_layers)
        ]
        self.integrator = NeuralODE(
            data_size=self.node_data_size,
            width_size=self.node_width,
            depth=self.node_depth,
            key=jr.PRNGKey(0),
        )

    def encode_input(self, inputs):
        x = inputs
        for idx, dense in enumerate(self.hidden_layers):
            x = dense(x)
            if idx == 0:
                x = 2 * jnp.pi * x
            x = jnp.sin(x)
        return x

    @nn.compact
    def __call__(self, inputs, t1=0):
        xt = inputs  # shape (N, 2), columns [x, t]
        f_raw = self.encode_input(xt)
        f_last = jax.vmap(self.integrator, in_axes=(None, 0))(t1, f_raw)  # (N, 512)
        return f_last
```
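The Python `for` loop in `encode_input` is unrolled when JAX traces it, so every layer adds its own copy of the operations to the compiled program. A minimal sketch of the `jax.lax.scan` alternative in plain JAX (not Flax's API; Flax has its own lifted `nn.scan`), assuming hypothetical stacked weight arrays `ws`/`bs` for the layers after the first, which all share shape `(n_nodes, n_nodes)`:

```python
import jax
import jax.numpy as jnp


def encode_loop(x, w0, b0, ws, bs):
    # Reference version: Python loop, unrolled at trace time (one copy of ops per layer).
    x = jnp.sin(2 * jnp.pi * (x @ w0 + b0))
    for w, b in zip(ws, bs):
        x = jnp.sin(x @ w + b)
    return x


def encode_scan(x, w0, b0, ws, bs):
    # The first layer maps 2 -> n_nodes and is scaled by 2*pi, so apply it outside the scan.
    x = jnp.sin(2 * jnp.pi * (x @ w0 + b0))

    def layer(carry, params):
        w, b = params
        return jnp.sin(carry @ w + b), None  # no per-layer outputs are collected

    # Remaining layers are stacked along a leading axis and traced once as a loop body.
    x, _ = jax.lax.scan(layer, x, (ws, bs))
    return x


k0, k1, k2 = jax.random.split(jax.random.PRNGKey(0), 3)
n_layers, n_nodes = 4, 8
w0 = jax.random.normal(k0, (2, n_nodes))
b0 = jnp.zeros(n_nodes)
ws = jax.random.normal(k1, (n_layers - 1, n_nodes, n_nodes)) / jnp.sqrt(n_nodes)
bs = 0.1 * jax.random.normal(k2, (n_layers - 1, n_nodes))
xt = jnp.array([[0.1, 0.2], [0.3, 0.4]])  # (N, 2) rows of [x, t]
out = encode_scan(xt, w0, b0, ws, bs)
```

With only one layer (the default `n_layers = 1`) the loop costs little, so this matters mainly when `n_layers` grows.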
The error I get is:
```
2025-09-16 19:18:24.767882: E external/xla/xla/service/slow_operation_alarm.cc:65] ******************************** [Compiling module jit_update] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results. ********************************
2025-09-16 19:23:11.188537: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 33.78GiB (36271006133 bytes) by rematerialization; only reduced to 62.50GiB (67109664958 bytes), down from 62.50GiB (67109716246 bytes) originally
2025-09-16 19:23:34.487702: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 7m9.719906955s ******************************** [Compiling module jit_update] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results. ********************************
2025-09-16 19:24:00.325370: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 64.71GiB (rounded to 69477018624)requested by op
2025-09-16 19:24:00.325549: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] *___________________________________________________________________________________________________
E0916 19:24:00.325599 2586479 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 69477018512 bytes.
```
Do you have any suggestions for solving this problem? I really appreciate your help and time!
If you want to use forward mode to save memory, there is `diffrax.ForwardMode()`.
Does this seem to be an OOM during runtime or an OOM during compilation? If the latter, it might be due to closing over a very large constant; if so, make sure such arrays are passed as inputs to the JIT'd region instead. (It looks like you're using Flax. If this is the culprit, I have no idea how you should avoid it when using Flax, though.)
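A minimal illustration of the difference, with a hypothetical array `big` standing in for whatever large data is being captured. Both functions compute the same result; the difference is that a closed-over array may be embedded into the compiled program as a constant, whereas an argument stays a runtime input.

```python
import jax
import jax.numpy as jnp

big = jnp.ones((512, 512))  # hypothetical stand-in for a large array


@jax.jit
def closed_over(x):
    # `big` is captured from the enclosing scope, so JAX may bake it into
    # the compiled program as a constant.
    return x @ big


@jax.jit
def as_input(x, w):
    # `w` is an explicit argument, so it remains an input to the compiled program.
    return x @ w


x = jnp.ones((4, 512))
y_closed = closed_over(x)
y_input = as_input(x, big)
```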
If nothing else, you're getting a "very slow compile" warning, which at least suggests that the JAX program you're describing is malformed, i.e. probably unnecessarily large. I can see a couple of places where things could be improved, e.g. combining your two `jax.jvp` calls via a `jax.vmap`, or using a `jax.lax.scan` instead of the Python `for` loop over layers.
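A minimal sketch of combining the two first-derivative `jax.jvp` calls into one vmapped `jax.jvp` over both tangent directions, using a hypothetical toy function `toy_f` in place of the real `f_vec`. Stacking the tangents this way is essentially what `jax.jacfwd` does, so `jax.jacfwd(f_vec)(z)` would give both columns at once as a `(D, 2)` Jacobian.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for f_vec: z = [x, t] (shape (2,)) -> shape (D,) with D = 2.
def toy_f(z):
    x, t = z[0], z[1]
    return jnp.stack([jnp.sin(x) * t, x**2 * t])


def first_derivs(z):
    tangents = jnp.eye(2)  # row 0 is e_x = [1, 0], row 1 is e_t = [0, 1]

    def jvp_along(v):
        return jax.jvp(toy_f, (z,), (v,))[1]

    # One jvp, vmapped over both tangent directions, instead of two separate jvps.
    d = jax.vmap(jvp_along)(tangents)  # (2, D): row 0 is f_x, row 1 is f_t
    return d[0], d[1]


xt = jnp.array([[0.5, 2.0], [1.0, 3.0]])  # (N, 2) rows of [x, t]
fx, ft = jax.vmap(first_derivs)(xt)
```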