
Out of Memory: How to reduce the memory cost of calculating derivatives w.r.t. the input

Open zhlfmyzhh opened this issue 3 months ago • 2 comments

I want to calculate the derivatives of the output of a neural ODE w.r.t. its input. However, I ran into memory issues. I think the memory-saving method RecursiveCheckpointAdjoint doesn't support forward-mode automatic differentiation (like jax.jvp or jax.jacfwd). In my case, I want to calculate the derivatives of f (the output of the neural ODE) w.r.t. x and t (the inputs), as below:

```python
import jax
import jax.numpy as jnp


def f_and_derivs_fast_vec(variables, apply_fn, xt, t1):
    """
    xt: (N,2) with columns [x, t]
    returns f, f_x, f_t, f_xx each (N, D)
    """
    xt = jnp.asarray(xt)

    def f_vec(z):                                # z: (2,) -> (D,)
        f = apply_fn(variables, z[None, :], t1=t1)   # model returns (1,D)
        return f[0]                                   # (D,)

    ex = jnp.array([1.0, 0.0])  # d/dx
    et = jnp.array([0.0, 1.0])  # d/dt

    def one_point(z):
        f     = f_vec(z)                             # (D,)
        _, fx  = jax.jvp(f_vec, (z,), (ex,))         # (D,)
        _, ft  = jax.jvp(f_vec, (z,), (et,))         # (D,)
        def gx(y): return jax.jvp(f_vec, (y,), (ex,))[1]  # f_x(y)
        _, fxx = jax.jvp(gx, (z,), (ex,))            # (D,)
        return f, fx, ft, fxx

    f, fx, ft, fxx = jax.vmap(one_point)(xt)
    return f, fx, ft, fxx
```

I define my Neural ODE as below:

```python
import diffrax
import equinox as eqx
import jax
import jax.nn as jnn
import jax.numpy as jnp


class Func(eqx.Module):
    out_scale: jax.Array
    mlp: eqx.nn.MLP

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.out_scale = jnp.array(1.0)
        self.mlp = eqx.nn.MLP(
            in_size=data_size,
            out_size=data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.swish,
            final_activation=jax.nn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.out_scale * self.mlp(y)

class NeuralODE(eqx.Module):
    func: Func

    def __init__(self, data_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.func = Func(data_size, width_size, depth, key=key)

    def __call__(self, t1, y0):
        y0 = jnp.asarray(y0).reshape(-1)  # (D,)
        solution = diffrax.diffeqsolve(
            diffrax.ODETerm(self.func),
            diffrax.Tsit5(),
            t0=0.0,
            t1=t1,
            dt0=1e-3,
            y0=y0,
            stepsize_controller=diffrax.PIDController(rtol=3e-3, atol=3e-6),
            saveat=diffrax.SaveAt(t1=True, dense=False),
            adjoint=diffrax.RecursiveCheckpointAdjoint(),
        )
        ys = jnp.asarray(solution.ys).reshape(-1)  # (D,)

        return ys
```

I define my model as below:

```python
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.random as jr


class PINN(nn.Module):
    n_nodes: int
    n_layers: int = 1
    node_data_size: int = 512  # Size of the data input to the NODE
    node_width: int = 64
    node_depth: int = 2

    def setup(self):
        self.hidden_layers = [nn.Dense(self.n_nodes, kernel_init=jax.nn.initializers.he_uniform())
                              for _ in range(self.n_layers)]
        self.integrator = NeuralODE(data_size=self.node_data_size, width_size=self.node_width, depth=self.node_depth, key=jr.PRNGKey(0))

    def encode_input(self, inputs):
        x = inputs
        for idx, dense in enumerate(self.hidden_layers):
            x = dense(x)
            if idx == 0:
                x = 2 * jnp.pi * x
            x = jnp.sin(x)
        return x

    @nn.compact
    def __call__(self, inputs, t1=0):
        xt = inputs  # shape (N, 1)

        f_raw = self.encode_input(xt)
        f_last = jax.vmap(self.integrator, in_axes=(None, 0))(t1, f_raw) # (N, 512)

        return f_last
```

The problem I ran into is:

2025-09-16 19:18:24.767882: E external/xla/xla/service/slow_operation_alarm.cc:65] ******************************** [Compiling module jit_update] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results. ********************************
2025-09-16 19:23:11.188537: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 33.78GiB (36271006133 bytes) by rematerialization; only reduced to 62.50GiB (67109664958 bytes), down from 62.50GiB (67109716246 bytes) originally
2025-09-16 19:23:34.487702: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 7m9.719906955s ******************************** [Compiling module jit_update] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results. ********************************
2025-09-16 19:24:00.325370: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 64.71GiB (rounded to 69477018624) requested by op
2025-09-16 19:24:00.325549: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] *___________________________________________________________________________________________________
E0916 19:24:00.325599 2586479 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 69477018512 bytes.

Do you have any suggestions for solving this problem? I really appreciate your help and time!

zhlfmyzhh avatar Sep 16 '25 13:09 zhlfmyzhh

If you want to use forward mode to save memory, there is diffrax.ForwardMode().
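
For reference, here is a minimal sketch of what that swap might look like in the diffeqsolve call from the question (assuming a diffrax version that provides ForwardMode, and reusing the names from NeuralODE.__call__ above):

```python
solution = diffrax.diffeqsolve(
    diffrax.ODETerm(self.func),
    diffrax.Tsit5(),
    t0=0.0,
    t1=t1,
    dt0=1e-3,
    y0=y0,
    stepsize_controller=diffrax.PIDController(rtol=3e-3, atol=3e-6),
    saveat=diffrax.SaveAt(t1=True, dense=False),
    # ForwardMode makes the solve compatible with forward-mode autodiff
    # (jax.jvp / jax.jacfwd), unlike RecursiveCheckpointAdjoint, which is
    # a reverse-mode adjoint.
    adjoint=diffrax.ForwardMode(),
)
```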

lockwo avatar Sep 16 '25 16:09 lockwo

Does this seem to be an OOM during runtime or an OOM during compilation? If the latter, then it might be due to closing over a very large constant. If so, then make sure these are inputs to the JIT'd region rather than closed-over values. (I can see you appear to be using Flax. If this is the culprit, then I have no idea how you should avoid it when using Flax, though.)
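
As a rough illustration of the closed-over-constant point (a toy array, nothing to do with your model specifically):

```python
import jax
import jax.numpy as jnp

big = jnp.ones((20_000, 512))  # stand-in for some large array

@jax.jit
def closed_over(x):
    # `big` is captured from the enclosing scope, so it gets baked into the
    # compiled program as a constant; very large constants can blow up both
    # compile time and compile-time memory.
    return x @ big

@jax.jit
def as_argument(x, big):
    # Passing it as an argument keeps it a runtime input instead.
    return x @ big
```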

If nothing else, I can see that you're getting a 'very slow compile' warning, which is at least suggestive that you're describing a malformed JAX program, i.e. one that is probably unnecessarily large. I can see a couple of places where things could improve here, e.g. combining your two jax.jvps via a jax.vmap, or using a jax.lax.scan instead of the for loop over layers.
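
A rough sketch of both of those suggestions, with a toy stand-in for f_vec and hypothetical layer parameters (shapes chosen arbitrarily):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the f_vec in the question: (2,) -> (D,).
def f_vec(z):
    return jnp.stack([jnp.sin(z[0]) * jnp.cos(z[1]), z[0] * z[1]])

z = jnp.array([0.3, 0.7])
directions = jnp.eye(2)  # rows are the d/dx and d/dt tangent directions

# One vmapped jax.jvp over both tangent directions, instead of two separate calls.
def jvp_dir(v):
    return jax.jvp(f_vec, (z,), (v,))[1]

fx, ft = jax.vmap(jvp_dir)(directions)  # each (D,)

# A jax.lax.scan over stacked, identically-shaped layer parameters,
# instead of a Python for loop over layers.
ws = jax.random.normal(jax.random.PRNGKey(0), (3, 8, 8))  # 3 layers of (8, 8) weights
bs = jnp.zeros((3, 8))

def layer(x, params):
    w, b = params
    return jnp.sin(x @ w + b), None

x_final, _ = jax.lax.scan(layer, jnp.ones(8), (ws, bs))
```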

patrick-kidger avatar Sep 18 '25 21:09 patrick-kidger