cyrus-baker

3 comments by cyrus-baker

I have tested `RecursiveCheckpointAdjoint` with `checkpoints=max_steps` using the following simple code:

```python
import jax.numpy as jnp
import diffrax as dfx

def vec_field(t, y, args):
    return -0.1 * y + 0.1 * jnp.sin(0.2 * t)

term = dfx.ODETerm(vec_field)
t0 = ...
```

It does seem that

> setting checkpoints to a very large number (e.g. equal to or exceeding actual step count) already emulate “all-save”

But is it exactly “all-save”? Or is...
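One way to frame the question is to separate two things "all-save" might mean: storing the state at every step, versus storing every intermediate quantity of every step. The following minimal pure-JAX sketch illustrates that distinction with explicit Euler and `jax.lax.scan` rather than diffrax, so it says nothing about diffrax's internal checkpointing algorithm; `DT`, `N_STEPS`, `euler_step` and `solve` are hypothetical names. Wrapping the step in `jax.checkpoint` makes the scan keep roughly only each step's carry and recompute the step's internals on the backward pass, while the plain scan keeps the step's residuals. In both cases the gradient is the same.

```python
import jax
import jax.numpy as jnp

DT = 0.01        # assumed step size
N_STEPS = 1000   # assumed number of steps (integrates from t = 0 to t = 10)


def vec_field(t, y):
    return -0.1 * y + 0.1 * jnp.sin(0.2 * t)


def euler_step(y, t):
    # One explicit-Euler step; returns (new carry, no per-step output).
    return y + DT * vec_field(t, y), None


def solve(y0, remat_each_step):
    # remat_each_step=True: save each step's input carry, recompute the step on the backward pass.
    # remat_each_step=False: save the step's residuals directly (closer to "all-save").
    step = jax.checkpoint(euler_step) if remat_each_step else euler_step
    ts = DT * jnp.arange(N_STEPS, dtype=jnp.float32)
    y_final, _ = jax.lax.scan(step, y0, ts)
    return y_final


y0 = jnp.asarray(1.0, dtype=jnp.float32)
g_all_save = jax.grad(solve)(y0, False)
g_remat = jax.grad(solve)(y0, True)

# For Euler on this linear ODE, d y_N / d y_0 = (1 - 0.1 * DT) ** N_STEPS.
print(g_all_save, g_remat, (1 - 0.1 * DT) ** N_STEPS)
```

To see what each variant actually keeps for the backward pass, `print_saved_residuals` from `jax.ad_checkpoint` can be applied to, e.g., `lambda y: solve(y, True)`.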

I made a small demo without using `diffeqsolve`:

```python
from functools import partial

import jax

@partial(jax.jit, static_argnums=(1,))
def f(y0, max_steps: int = 100_000):
    # initialize carry
    state0 = solver.init(term, t0, t0 + dt0, y0, None)
    carry0...
```