Question: DirectAdjoint is faster than RecursiveCheckpointAdjoint?
Hi,
According to the suggestions in the adjoints docs, the RecursiveCheckpointAdjoint method, given enough checkpoints, should be faster than DirectAdjoint. But in my experience it turns out that DirectAdjoint is faster. Is there something wrong in my understanding?
Example
An example can be shown using the neural_cde tutorial:
- If I use the default setting, adjoint = RecursiveCheckpointAdjoint(), the total run time (excluding the first step, i.e. compilation) is 1.38 s.
- If I use adjoint = RecursiveCheckpointAdjoint(checkpoints=4096), the run time is 1.18 s.
- If I use adjoint = diffrax.DirectAdjoint(), the run time is 0.77 s.
(A minimal sketch of the setting being changed is shown below.)
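For reference, the only thing changed between the three settings is the adjoint argument passed to diffrax.diffeqsolve. A minimal sketch (a toy ODE here, not the full neural CDE from the tutorial):

import diffrax
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)  # toy vector field
solver = diffrax.Tsit5()
for adjoint in (
    diffrax.RecursiveCheckpointAdjoint(),                 # default
    diffrax.RecursiveCheckpointAdjoint(checkpoints=4096),  # explicit checkpoints
    diffrax.DirectAdjoint(),
):
    sol = diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, jnp.array(1.0), adjoint=adjoint)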
Environment
- jax == 0.4.29
- jaxlib == 0.4.29
- diffrax == 0.6.0
- platform: CPU
Hi,
What exactly are you benchmarking? The runtime of main()?
Comparing the runtimes of NeuralCDE.__call__ with RecursiveCheckpointAdjoint() and DirectAdjoint() I get equivalent times in microbenchmarks (1.93 ms for the direct adjoint and 1.94 ms for the recursive one, mean of 1000 loops). Evaluating their gradients takes the same amount of time as well (8.35 vs 8.33 ms, mean of 100 loops).
Both times I used sample_ts, sample_coeffs as input, same as done for the plotting section of the example.
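In case it helps to compare like with like, this is roughly the shape of that microbenchmark (model here stands for a NeuralCDE built with whichever adjoint is being timed; details may differ from my exact setup):

import equinox as eqx
import jax

run = eqx.filter_jit(model)
grad_run = eqx.filter_jit(eqx.filter_grad(model))
_ = run(sample_ts, sample_coeffs)  # compile first
_ = jax.block_until_ready(grad_run(sample_ts, sample_coeffs))
%timeit run(sample_ts, sample_coeffs).block_until_ready()
%timeit jax.block_until_ready(grad_run(sample_ts, sample_coeffs))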
Hi,
I simply measure the runtime of the training steps in main(), using the code below:
total_time = 0
for step, data_i in zip(
    range(steps), dataloader((ts, labels) + coeffs, batch_size, key=loader_key)
):
    start = time.time()
    bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
    end = time.time()
    time_i = end - start
    if step > 0:  # don't count the compilation time
        total_time += time_i
    print(
        f"Step: {step}, Loss: {bxe}, Accuracy: {acc}, Computation time: "
        f"{end - start}"
    )
print(f'total time: {total_time}')
If I instead benchmark the make_step() function using this command:
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
the result is 64.2 ms for RecursiveCheckpointAdjoint() and 35.1 ms for DirectAdjoint().
It's really confusing.
Have you called make_step before, to ensure that everything is compiled? Something like this
run_fn = eqx.filter_jit(fn)
_ = run_fn(inputs)
%timeit run_fn(inputs).block_until_ready()
Yes, I added this command below the training steps, like this:
for step, data_i in zip(
    range(steps), dataloader((ts, labels) + coeffs, batch_size, key=loader_key)
):
    bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
%timeit jax.block_until_ready(make_step(model, data_i, opt_state))
And the make_step() function inside this notebook is jitted with @eqx.filter_jit.
I cannot reproduce this; I still get equivalent runtimes if I add a %timeit exactly where you do (134 and 137 ms, mean of 10 loops).
How/where do you specify which adjoint to use?
@CoastEgo -- do you have a copy-pastable MWE? (E.g. by starting with the neural CDE example and then minimising it down to something that fits in a GitHub message.) Just to be sure we're all running exactly the same code!
Also, what version of Equinox are you using? The underlying implementation of the while loops (which these adjoint methods use) belongs to Equinox.
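(For reference, all the relevant versions can be printed with something like this:)

import diffrax, equinox, jax, jaxlib
print(jax.__version__, jaxlib.__version__, diffrax.__version__, equinox.__version__)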
It's definitely expected that RecursiveCheckpointAdjoint should be the better choice!
> It's definitely expected that RecursiveCheckpointAdjoint should be the better choice!
Even for such a small example? The solver just takes 50 steps.
> Even for such a small example? The solver just takes 50 steps.
Yup! In fact especially so. The cost of DirectAdjoint actually grows primarily with max_steps, not with the number of steps actually taken.
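If the explanation above holds, one rough way to see it is a toy fixed-step solve that always takes exactly 50 steps, with only max_steps varying (a sketch, not a rigorous benchmark):

import diffrax
import equinox as eqx
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: -y)
solver = diffrax.Tsit5()

@eqx.filter_jit
def solve(y0, max_steps):
    return diffrax.diffeqsolve(
        term, solver, 0.0, 1.0, 0.02, y0,  # constant steps: exactly 50 of them
        adjoint=diffrax.DirectAdjoint(),
        max_steps=max_steps,
    ).ys

_ = solve(jnp.array(1.0), 256)   # compile both variants
_ = solve(jnp.array(1.0), 4096)
%timeit solve(jnp.array(1.0), 256).block_until_ready()
%timeit solve(jnp.array(1.0), 4096).block_until_ready()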
Sorry for the confusion! Here is the code
Code
import math

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import optax  # https://github.com/deepmind/optax


class Func(eqx.Module):
    mlp: eqx.nn.MLP
    data_size: int
    hidden_size: int

    def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.data_size = data_size
        self.hidden_size = hidden_size
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size * data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            # Note the use of a tanh final activation function. This is important to
            # stop the model blowing up. (Just like how GRUs and LSTMs constrain the
            # rate of change of their hidden states.)
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y).reshape(self.hidden_size, self.data_size)


class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func
    linear: eqx.nn.Linear
    adjoint_state: int

    def __init__(self, data_size, hidden_size, width_size, depth, adjoint_state, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)
        self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)
        self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)
        self.adjoint_state = adjoint_state

    def __call__(self, ts, coeffs):
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = None
        if self.adjoint_state == 0:
            adjoint = diffrax.RecursiveCheckpointAdjoint()
        else:
            adjoint = diffrax.DirectAdjoint()
        # adjoint = diffrax.DirectAdjoint()
        y0 = self.initial(control.evaluate(ts[0]))
        saveat = diffrax.SaveAt(t1=True)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=saveat,
            adjoint=adjoint,
        )
        (prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))
        return prediction


def get_data(dataset_size, add_noise, *, key):
    theta_key, noise_key = jr.split(key, 2)
    length = 100
    theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)
    y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
    ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))
    matrix = jnp.array([[-0.3, 2], [-2, -0.3]])
    ys = jax.vmap(
        lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)
    )(y0, ts)
    ys = jnp.concatenate([ts[:, :, None], ys], axis=-1)  # time is a channel
    ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)
    if add_noise:
        ys = ys + jr.normal(noise_key, ys.shape) * 0.1
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
    labels = jnp.zeros((dataset_size,))
    labels = labels.at[: dataset_size // 2].set(1.0)
    _, _, data_size = ys.shape
    return ts, coeffs, labels, data_size


def main(
    dataset_size=256,
    add_noise=False,
    batch_size=32,
    lr=1e-2,
    steps=20,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)
    ts, coeffs, labels, data_size = get_data(
        dataset_size, add_noise, key=train_data_key
    )

    # Training loop like normal.
    @eqx.filter_jit
    def loss(model, ti, label_i, coeff_i):
        pred = jax.vmap(model)(ti, coeff_i)
        # Binary cross-entropy
        bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)
        bxe = -jnp.mean(bxe)
        acc = jnp.mean((pred > 0.5) == (label_i == 1))
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, data_i, opt_state):
        ti, label_i, *coeff_i = data_i
        (bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return bxe, acc, model, opt_state

    optim = optax.adam(lr)
    data_all = (ts, labels) + coeffs
    data_i = (a[:batch_size] for a in data_all)

    print('recursive checkpoint adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth, adjoint_state=0, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

    data_i = (a[:batch_size] for a in data_all)
    print('direct adjoint')
    optim = optax.adam(lr)
    model = NeuralCDE(data_size, hidden_size, width_size, depth, adjoint_state=1, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = make_step(model, data_i, opt_state)
    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

main()
This gives the following result with equinox == 0.11.4:
#recursive checkpoint adjoint
#74.4 ms ± 95.1 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
#direct adjoint
#46.5 ms ± 43.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
When I upgrade all of my packages (jax==0.4.38, jaxlib==0.4.38, diffrax==0.6.2, equinox==0.11.11), this problem disappears.
That gives a result like this: the same performance for both adjoints, but slower overall 😢
#recursive checkpoint adjoint
#160 ms ± 6.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#direct adjoint
#156 ms ± 3.49 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Hi both,
on newest everything, using your code @CoastEgo, the direct adjoint is a little faster for me (126 ms vs 156 ms). This is when benchmarking make_step, exactly as you do.
However, this does not carry over to simply solving the differential equation, or to taking its derivative. If I run the neural CDE from the example, the integration alone is ~20-25 % faster with the checkpointed adjoint. This increases to 33 % for the gradients. You can paste this code below yours:
key = jr.key(0)
ts, coeffs, labels, data_size = get_data(dataset_size=256, add_noise=False, key=key)
sample_ts = ts[-1]
sample_coeffs = tuple(c[-1] for c in coeffs)
recursive = NeuralCDE(data_size, 8, 128, 1, adjoint_state=0, key=key)
direct = NeuralCDE(data_size, 8, 128, 1, adjoint_state=1, key=key)
run_recursive = eqx.filter_jit(recursive)
run_direct = eqx.filter_jit(direct)
run_grad_recursive = eqx.filter_jit(eqx.filter_grad(recursive))
run_grad_direct = eqx.filter_jit(eqx.filter_grad(direct))
_ = run_recursive(sample_ts, sample_coeffs)
_ = run_direct(sample_ts, sample_coeffs)
_ = run_grad_recursive(sample_ts, sample_coeffs)
_ = run_grad_direct(sample_ts, sample_coeffs)
print("Timing recursive checkpoint adjoint")
%timeit run_recursive(sample_ts, sample_coeffs).block_until_ready()
print("Timing direct adjoint")
%timeit run_direct(sample_ts, sample_coeffs).block_until_ready()
print("Timing recursive checkpoint adjoint with gradients")
%timeit run_grad_recursive(sample_ts, sample_coeffs).block_until_ready()
print("Timing direct adjoint with gradients")
%timeit run_grad_direct(sample_ts, sample_coeffs).block_until_ready()
Edit: update to switch to a better example, and include gradients.
Hi @johannahaffner, with the newest packages I get the same result as you with your code. I guess the different benchmark results between make_step() and NeuralCDE.__call__() come down to two reasons.
- The first reason is that inside make_step() we use the jax.vmap version of the model. If I run your code, I get the result below (recursive is twice as fast!):
Timing recursive checkpoint adjoint
3 ms ± 53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Timing direct adjoint
3.54 ms ± 222 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Timing recursive checkpoint adjoint with gradients
9.8 ms ± 447 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Timing direct adjoint with gradients
18.5 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
But if I run a vmap version of your code, the result is different (the direct adjoint is slightly faster). This may be related to the while_loop inside vmap?
batch_size = 32
key = jr.key(0)
ts, coeffs, labels, data_size = get_data(dataset_size=256, add_noise=False, key=key)
sample_ts = ts[:batch_size]
sample_coeffs = tuple(c[:batch_size] for c in coeffs)
recursive = NeuralCDE(data_size, 8, 128, 1, adjoint_state=0, key=key)
direct = NeuralCDE(data_size, 8, 128, 1, adjoint_state=1, key=key)
@eqx.filter_jit
def run_recursive(ts, coeffs):
    result = jax.vmap(recursive)(ts, coeffs)
    return result.mean()

@eqx.filter_jit
def run_direct(ts, coeffs):
    result = jax.vmap(direct)(ts, coeffs)
    return result.mean()
run_grad_recursive = eqx.filter_jit(eqx.filter_grad(run_recursive))
run_grad_direct = eqx.filter_jit(eqx.filter_grad(run_direct))
_ = run_recursive(sample_ts, sample_coeffs)
_ = run_direct(sample_ts, sample_coeffs)
_ = run_grad_recursive(sample_ts, sample_coeffs)
_ = run_grad_direct(sample_ts, sample_coeffs)
print("Timing recursive checkpoint adjoint")
%timeit run_recursive(sample_ts, sample_coeffs).block_until_ready()
print("Timing direct adjoint")
%timeit run_direct(sample_ts, sample_coeffs).block_until_ready()
print("Timing recursive checkpoint adjoint with gradients")
%timeit jax.block_until_ready(run_grad_recursive(sample_ts, sample_coeffs))
print("Timing direct adjoint with gradients")
%timeit jax.block_until_ready(run_grad_direct(sample_ts, sample_coeffs))
# Timing recursive checkpoint adjoint
# 52.8 ms ± 336 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Timing direct adjoint
# 52.9 ms ± 116 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Timing recursive checkpoint adjoint with gradients
# 169 ms ± 7.93 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# Timing direct adjoint with gradients
# 147 ms ± 5.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Edit: deleted the second reason. It turned out to be my mistake.
I've just run this on my machine and:
- With the benchmarking code provided by @CoastEgo I can reproduce the difference between DirectAdjoint and RecursiveCheckpointAdjoint. That's unfortunate!
- When I replace the vmap with a single model invocation -- complete hack, but this:

@eqx.filter_jit
def loss(model, ti, label_i, coeff_i):
    ti, label_i, coeff_i = jax.tree.map(lambda x: x[0], (ti, label_i, coeff_i))
    pred = model(ti, coeff_i)[None]
    label_i = label_i[None]
    ...

then the performance is again in RecursiveCheckpointAdjoint's favour. So indeed it definitely seems like a vmap issue.
- I think we'll need to simplify this down, e.g. to just an ODE, or even just a direct invocation of the underlying eqxi.while_loop(..., kind="bounded") (= DirectAdjoint) vs eqxi.while_loop(..., kind="checkpointed") (= RecursiveCheckpointAdjoint).
At least for the speed difference between Equinox 0.11.4 and Equinox 0.11.11, I'm able to resolve this one pretty easily :) I think all that's happening is that you're also picking up a change in JAX version at the same time, and modern JAX has a known performance bug that is resolved by setting
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"
with this I'm able to get your first faster set of numbers even on the latest versions of JAX and Equinox. I'll go ahead and ask over on the JAX issue tracker what the plan is for that. (EDIT: https://github.com/jax-ml/jax/discussions/25711)
Note that your example does have a small bug by the way: you have a generator data_i = (a[:batch_size] for a in data_all) rather than a tuple data_i = tuple(a[:batch_size] for a in data_all). Strangely this doesn't seem to impact the results for me, though.
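Just to spell out that fix for anyone copy-pasting (nothing else changes):

# before (bug): a generator expression, which is exhausted after a single use
data_i = (a[:batch_size] for a in data_all)
# after (fix): materialise the batch as a tuple so every make_step call can unpack it
data_i = tuple(a[:batch_size] for a in data_all)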
Edit: my bad, the example below only has the recursive checkpoint adjoint lose its edge under vmap if it is not jitted. I just noticed it was missing this after I had posted.
Under JIT, I can no longer reproduce the performance-drop-under-vmap issue when using the introductory CDE example with a quadratic path.
import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
class QuadraticPath(dfx.AbstractPath):
    @property
    def t0(self):
        return 0

    @property
    def t1(self):
        return 3

    def evaluate(self, t0, t1=None, left=True):
        del left
        if t1 is not None:
            return self.evaluate(t1) - self.evaluate(t0)
        return t0 ** 2
# args unused - can be anything
vector_field = lambda t, y, args: -y
control = QuadraticPath()
term = dfx.ControlTerm(vector_field, control).to_ode()
solver = dfx.Dopri5()
def make_solve(term, solver, adjoint, vmap=False):
    def solve(args):
        return dfx.diffeqsolve(term, solver, 0, 3, 0.05, 1, args, adjoint=adjoint).ys

    if vmap:
        return eqx.filter_jit(eqx.filter_vmap(solve))
    else:
        return eqx.filter_jit(solve)
direct = dfx.DirectAdjoint()
recursive = dfx.RecursiveCheckpointAdjoint()
direct_solve = make_solve(term, solver, direct)
recursive_solve = make_solve(term, solver, recursive)
vmap_direct_solve = make_solve(term, solver, direct, vmap=True)
vmap_recursive_solve = make_solve(term, solver, recursive, vmap=True)
_ = direct_solve(None) # warmup
_ = recursive_solve(None)
dummy_args = jnp.zeros((8,))
_ = vmap_direct_solve(dummy_args)
_ = vmap_recursive_solve(dummy_args)
print("Timing direct adjoint: quadratic path")
%timeit direct_solve(None).block_until_ready()
print("Timing recursive checkpoint adjoint: quadratic path")
%timeit recursive_solve(None).block_until_ready()
print("Timing direct adjoint: quadratic path (vmap)")
%timeit vmap_direct_solve(dummy_args).block_until_ready()
print("Timing recursive checkpoint adjoint: quadratic path (vmap)")
%timeit vmap_recursive_solve(dummy_args).block_until_ready()
Also, looking at just the while loops, I see the same thing (i.e. checkpoint being faster).
Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"
import equinox as eqx
import equinox.internal as eqxi
import jax.numpy as jnp
import jax
t0 = 0.0
t1 = 3.0
dt = 0.1
N = int((t1 - t0) / dt)
t = jnp.arange(t0, t1 + dt, dt)
param = -1.0
def cond_fun(carry):
    i, y_array, y_cur = carry
    return i < N

def body_fun(carry):
    i, y_array, y_cur = carry
    y_next = y_cur + dt * (param * y_cur)
    return i+1, y_array.at[i+1].set(y_next), y_next

def cond_fun(carry):
    i, y_array, y_cur = carry
    return i < N

def body_fun(carry):
    i, y_array, y_cur = carry
    y_next = y_cur + dt * -y_cur
    return i+1, y_array.at[i+1].set(y_next), y_next
init_val = (0, jnp.zeros(N+1), 1.0)
_, y_sol, _ = jax.lax.while_loop(cond_fun, body_fun, init_val)
def make_solve(loop_type, vmap=False):
    def solve(args):
        # return jax.lax.while_loop(cond_fun, body_fun, args)[1]
        return eqxi.while_loop(cond_fun, body_fun, args, kind=loop_type, max_steps=4096)[1]

    if vmap:
        return eqx.filter_jit(eqx.filter_vmap(solve, in_axes=((None, 0, None),)))
    else:
        return eqx.filter_jit(solve)
direct_solve = make_solve("bounded")
recursive_solve = make_solve("checkpointed")
vmap_direct_solve = make_solve("bounded", vmap=True)
vmap_recursive_solve = make_solve("checkpointed", vmap=True)
_ = direct_solve(init_val).block_until_ready()
_ = recursive_solve(init_val).block_until_ready()
init_val_vmap = (0, jnp.zeros((800, N+1)), 1.0)
_ = vmap_direct_solve(init_val_vmap).block_until_ready()
_ = vmap_recursive_solve(init_val_vmap).block_until_ready()
print("Bounded while single")
%timeit direct_solve(init_val).block_until_ready()
print("Checkpoint while single")
%timeit recursive_solve(init_val).block_until_ready()
print("Bounded while vmap")
%timeit vmap_direct_solve(init_val_vmap).block_until_ready()
print("Checkpoint while vmap")
%timeit vmap_recursive_solve(init_val_vmap).block_until_ready()
Bounded while single
112 µs ± 2.08 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Checkpoint while single
112 µs ± 1.81 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Bounded while vmap
151 µs ± 2.14 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
Checkpoint while vmap
118 µs ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
One thing that might be causing a difference between @johannahaffner's and my simpler recreations and the original code is the adaptivity. I tested without adaptivity:
Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import math

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import optax  # https://github.com/deepmind/optax


class Func(eqx.Module):
    mlp: eqx.nn.MLP
    data_size: int
    hidden_size: int

    def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.data_size = data_size
        self.hidden_size = hidden_size
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size * data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            # Note the use of a tanh final activation function. This is important to
            # stop the model blowing up. (Just like how GRUs and LSTMs constrain the
            # rate of change of their hidden states.)
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y).reshape(self.hidden_size, self.data_size)


class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func
    linear: eqx.nn.Linear
    adjoint_state: int

    def __init__(self, data_size, hidden_size, width_size, depth, adjoint_state, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)
        self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)
        self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)
        self.adjoint_state = adjoint_state

    def __call__(self, ts, coeffs):
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = ts[-1] / 100
        if self.adjoint_state == 0:
            adjoint = diffrax.RecursiveCheckpointAdjoint()
        else:
            adjoint = diffrax.DirectAdjoint()
        y0 = self.initial(control.evaluate(ts[0]))
        saveat = diffrax.SaveAt(t1=True)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            # stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=saveat,
            adjoint=adjoint,
        )
        (prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))
        return prediction


def get_data(dataset_size, add_noise, *, key):
    theta_key, noise_key = jr.split(key, 2)
    length = 100
    theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)
    y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
    ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))
    matrix = jnp.array([[-0.3, 2], [-2, -0.3]])
    ys = jax.vmap(
        lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)
    )(y0, ts)
    ys = jnp.concatenate([ts[:, :, None], ys], axis=-1)  # time is a channel
    ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)
    if add_noise:
        ys = ys + jr.normal(noise_key, ys.shape) * 0.1
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
    labels = jnp.zeros((dataset_size,))
    labels = labels.at[: dataset_size // 2].set(1.0)
    _, _, data_size = ys.shape
    return ts, coeffs, labels, data_size


def main(
    dataset_size=256,
    add_noise=False,
    batch_size=32,
    lr=1e-2,
    steps=20,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)
    ts, coeffs, labels, data_size = get_data(
        dataset_size, add_noise, key=train_data_key
    )

    @eqx.filter_jit
    def loss(model, ti, label_i, coeff_i):
        pred = jax.vmap(model)(ti, coeff_i)
        bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)
        bxe = -jnp.mean(bxe)
        acc = jnp.mean((pred > 0.5) == (label_i == 1))
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, data_i, opt_state):
        ti, label_i, *coeff_i = data_i
        (bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)
        updates, opt_state = optim.update(grads, opt_state)
        model = eqx.apply_updates(model, updates)
        return bxe, acc, model, opt_state

    optim = optax.adam(lr)
    data_all = (ts, labels) + coeffs
    data_i = tuple(a[:batch_size] for a in data_all)

    print('recursive checkpoint adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth, adjoint_state=0, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

    print('direct adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth, adjoint_state=1, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

main()
and it didn't seem to make much of a difference. Next I thought: maybe the size/cost of the vector field? The neural network is much more compute intensive than our two toy examples. You do see some time difference here (checkpoint is slower than bounded, which it wasn't before, but not by much; I don't know if that's expected).
Code + Results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.nn as jnn
import jax.numpy as jnp
print(jax.__version__)
print(eqx.__version__)
key = jax.random.PRNGKey(0)
model = eqx.nn.MLP(
    1,
    1,
    128,
    8,
    activation=jnn.softplus,
    final_activation=jnn.tanh,
    key=key,
)
t0 = 0.0
t1 = 30.0
dt = 0.1
N = int((t1 - t0) / dt)
t = jnp.arange(t0, t1 + dt, dt)
def cond_fun(carry):
    i, y_array, y_cur = carry
    return i < N

def body_fun(carry):
    i, y_array, y_cur = carry
    y_next = y_cur + dt * model(jnp.array([y_cur]))[0]
    return i + 1, y_array.at[i+1].set(y_next), y_next
init_val = (jnp.zeros(N+1), jnp.array(1.0))
def make_solve(loop_type, vmap=False):
    def solve(args):
        args = (0, *args)
        max_steps = 4096
        return eqxi.while_loop(cond_fun, body_fun, args, kind=loop_type, max_steps=max_steps)[1]

    if vmap:
        return eqx.filter_jit(eqx.filter_vmap(solve))
    else:
        return eqx.filter_jit(solve)
direct_solve = make_solve("bounded")
recursive_solve = make_solve("checkpointed")
vmap_direct_solve = make_solve("bounded", vmap=True)
vmap_recursive_solve = make_solve("checkpointed", vmap=True)
_ = direct_solve(init_val).block_until_ready()
_ = recursive_solve(init_val).block_until_ready()
init_val_vmap = (jnp.zeros((32, N+1)), jnp.ones(32))
_ = vmap_direct_solve(init_val_vmap).block_until_ready()
_ = vmap_recursive_solve(init_val_vmap).block_until_ready()
print("Bounded while single")
%timeit direct_solve(init_val).block_until_ready()
print("Checkpoint while single")
%timeit recursive_solve(init_val).block_until_ready()
print("Bounded while vmap")
%timeit vmap_direct_solve(init_val_vmap).block_until_ready()
print("Checkpoint while vmap")
%timeit vmap_recursive_solve(init_val_vmap).block_until_ready()
0.4.38
0.11.11
Bounded while single
3.57 ms ± 106 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Checkpoint while single
3.6 ms ± 110 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Bounded while vmap
73.2 ms ± 259 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Checkpoint while vmap
78.8 ms ± 1.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Then lastly, I thought maybe it isn't the loop that's slow but the gradient (since that's also a difference: so far we just did inference). And indeed, if you remove the gradient calculation, you see both loops are the same speed. (Again, not sure if that's totally expected, but hopefully this provides some more useful data points.)
Code + results
import os
os.environ["XLA_FLAGS"] = "--xla_cpu_use_thunk_runtime=false"

import math

import diffrax
import equinox as eqx  # https://github.com/patrick-kidger/equinox
import jax
import jax.nn as jnn
import jax.numpy as jnp
import jax.random as jr
import jax.scipy as jsp
import optax  # https://github.com/deepmind/optax


class Func(eqx.Module):
    mlp: eqx.nn.MLP
    data_size: int
    hidden_size: int

    def __init__(self, data_size, hidden_size, width_size, depth, *, key, **kwargs):
        super().__init__(**kwargs)
        self.data_size = data_size
        self.hidden_size = hidden_size
        self.mlp = eqx.nn.MLP(
            in_size=hidden_size,
            out_size=hidden_size * data_size,
            width_size=width_size,
            depth=depth,
            activation=jnn.softplus,
            # Note the use of a tanh final activation function. This is important to
            # stop the model blowing up. (Just like how GRUs and LSTMs constrain the
            # rate of change of their hidden states.)
            final_activation=jnn.tanh,
            key=key,
        )

    def __call__(self, t, y, args):
        return self.mlp(y).reshape(self.hidden_size, self.data_size)


class NeuralCDE(eqx.Module):
    initial: eqx.nn.MLP
    func: Func
    linear: eqx.nn.Linear
    adjoint_state: int

    def __init__(self, data_size, hidden_size, width_size, depth, adjoint_state, *, key, **kwargs):
        super().__init__(**kwargs)
        ikey, fkey, lkey = jr.split(key, 3)
        self.initial = eqx.nn.MLP(data_size, hidden_size, width_size, depth, key=ikey)
        self.func = Func(data_size, hidden_size, width_size, depth, key=fkey)
        self.linear = eqx.nn.Linear(hidden_size, 1, key=lkey)
        self.adjoint_state = adjoint_state

    def __call__(self, ts, coeffs):
        control = diffrax.CubicInterpolation(ts, coeffs)
        term = diffrax.ControlTerm(self.func, control).to_ode()
        solver = diffrax.Tsit5()
        dt0 = ts[-1] / 100
        if self.adjoint_state == 0:
            adjoint = diffrax.RecursiveCheckpointAdjoint()
        else:
            adjoint = diffrax.DirectAdjoint()
        y0 = self.initial(control.evaluate(ts[0]))
        saveat = diffrax.SaveAt(t1=True)
        solution = diffrax.diffeqsolve(
            term,
            solver,
            ts[0],
            ts[-1],
            dt0,
            y0,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            saveat=saveat,
            adjoint=adjoint,
        )
        (prediction,) = jnn.sigmoid(self.linear(solution.ys[-1]))
        return prediction


def get_data(dataset_size, add_noise, *, key):
    theta_key, noise_key = jr.split(key, 2)
    length = 100
    theta = jr.uniform(theta_key, (dataset_size,), minval=0, maxval=2 * math.pi)
    y0 = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=-1)
    ts = jnp.broadcast_to(jnp.linspace(0, 4 * math.pi, length), (dataset_size, length))
    matrix = jnp.array([[-0.3, 2], [-2, -0.3]])
    ys = jax.vmap(
        lambda y0i, ti: jax.vmap(lambda tij: jsp.linalg.expm(tij * matrix) @ y0i)(ti)
    )(y0, ts)
    ys = jnp.concatenate([ts[:, :, None], ys], axis=-1)  # time is a channel
    ys = ys.at[: dataset_size // 2, :, 1].multiply(-1)
    if add_noise:
        ys = ys + jr.normal(noise_key, ys.shape) * 0.1
    coeffs = jax.vmap(diffrax.backward_hermite_coefficients)(ts, ys)
    labels = jnp.zeros((dataset_size,))
    labels = labels.at[: dataset_size // 2].set(1.0)
    _, _, data_size = ys.shape
    return ts, coeffs, labels, data_size


def main(
    dataset_size=256,
    add_noise=False,
    batch_size=32,
    lr=1e-2,
    steps=20,
    hidden_size=8,
    width_size=128,
    depth=1,
    seed=5678,
):
    key = jr.PRNGKey(seed)
    train_data_key, test_data_key, model_key, loader_key = jr.split(key, 4)
    ts, coeffs, labels, data_size = get_data(
        dataset_size, add_noise, key=train_data_key
    )

    @eqx.filter_jit
    def loss(model, ti, label_i, coeff_i):
        pred = jax.vmap(model)(ti, coeff_i)
        bxe = label_i * jnp.log(pred) + (1 - label_i) * jnp.log(1 - pred)
        bxe = -jnp.mean(bxe)
        acc = jnp.mean((pred > 0.5) == (label_i == 1))
        return bxe, acc

    grad_loss = eqx.filter_value_and_grad(loss, has_aux=True)

    @eqx.filter_jit
    def make_step(model, data_i, opt_state):
        ti, label_i, *coeff_i = data_i
        # (bxe, acc), grads = grad_loss(model, ti, label_i, coeff_i)
        bxe, acc = loss(model, ti, label_i, coeff_i)
        return bxe, acc, model, opt_state
        # updates, opt_state = optim.update(grads, opt_state)
        # model = eqx.apply_updates(model, updates)
        # return bxe, acc, model, opt_state

    optim = optax.adam(lr)
    data_all = (ts, labels) + coeffs
    data_i = tuple(a[:batch_size] for a in data_all)

    print('recursive checkpoint adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth, adjoint_state=0, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

    print('direct adjoint')
    model = NeuralCDE(data_size, hidden_size, width_size, depth, adjoint_state=1, key=model_key)
    opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))
    bxe, acc, model, opt_state = jax.block_until_ready(make_step(model, data_i, opt_state))
    %timeit jax.block_until_ready(make_step(model, data_i, opt_state))

main()
recursive checkpoint adjoint
18.8 ms ± 211 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
direct adjoint
18.6 ms ± 40.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
That looks really thorough, @lockwo!
I was also looking into the adaptivity, but only got something very preliminary, and I could not immediately confirm it by switching to constant step sizes.
For a small batch and dataset size (8 and 32), it looks like more steps are taken with the recursive adjoint. Since the number of steps passes a power of 2 at 64, maybe this is a memory allocation thing? I won't have time to look at this in more detail before the weekend though, so all I can do is leave you with this very ugly plot.
Since you touch on the magic number 64: FWIW DirectAdjoint is implemented in terms of nested jax.lax.scans of length 8, and therefore starts getting different speed/memory (but hopefully not steps) past every power of 8.
But if we're getting different numbers of steps between the adjoints then something has gone very wrong! These should be identical in terms of the mathematics they compute.
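To make the nested-scan picture concrete, here is a toy illustration of the general idea (this is not Equinox's actual implementation): two nested lax.scans of length 8 cover 8 * 8 = 64 possible iterations, and once the loop condition becomes False every remaining iteration leaves the carry unchanged.

import jax
import jax.lax as lax
import jax.numpy as jnp

def bounded_while_64(cond_fun, body_fun, init):
    # One "iteration": run the body, but keep the old carry if the condition is False.
    def step(carry, _):
        new_carry = jax.tree.map(
            lambda new, old: jnp.where(cond_fun(carry), new, old), body_fun(carry), carry
        )
        return new_carry, None

    def outer(carry, _):
        carry, _ = lax.scan(step, carry, None, length=8)
        return carry, None

    carry, _ = lax.scan(outer, init, None, length=8)
    return carry

# Stops advancing after 50 of the 64 possible iterations.
print(bounded_while_64(lambda c: c < 50, lambda c: c + 1, jnp.array(0)))  # 50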
Do you have an MWE for the different step counts? I will say I have encountered some interesting edge cases before with vmap of doubly-nested while loops (like with adaptive stepping), but my minimal example seems to work as expected.
Code
import diffrax as dfx
import equinox as eqx
import jax.numpy as jnp
import jax
def vector_field(t, y, args):
    prey, predator = y[0], y[1]
    α, β, γ, δ = (0.1, 0.02, 0.4, 0.02)
    d_prey = α * prey - β * prey * predator
    d_predator = -γ * predator + δ * prey * predator
    d_y = jnp.array([d_prey, d_predator])
    return d_y
term = dfx.ODETerm(vector_field)
solver = dfx.Dopri5()
def make_solve(term, solver, adjoint, vmap=False):
    def solve(init):
        return dfx.diffeqsolve(term, solver, 0, 200, 0.05, init,
                               None, adjoint=adjoint, stepsize_controller=dfx.PIDController(1e-3, 1e-6))

    if vmap:
        return eqx.filter_jit(eqx.filter_vmap(solve))
    else:
        return eqx.filter_jit(solve)
direct = dfx.DirectAdjoint()
recursive = dfx.RecursiveCheckpointAdjoint()
direct_solve = make_solve(term, solver, direct)
recursive_solve = make_solve(term, solver, recursive)
vmap_direct_solve = make_solve(term, solver, direct, vmap=True)
vmap_recursive_solve = make_solve(term, solver, recursive, vmap=True)
key = jax.random.key(0)
key, subkey = jax.random.split(key)
inits = jax.random.uniform(subkey, (2,))
_ = direct_solve(inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
_ = recursive_solve(inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
key, subkey = jax.random.split(key)
vmap_inits = jax.random.uniform(subkey, (10, 2))
_ = vmap_direct_solve(vmap_inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
_ = vmap_recursive_solve(vmap_inits)
print(_.stats["num_accepted_steps"], _.stats["num_rejected_steps"], _.stats["num_steps"])
# print("Timing direct adjoint: quadratic path")
# %timeit direct_solve(None).block_until_ready()
# print("Timing recursive checkpoint adjoint: quadratic path")
# %timeit recursive_solve(None).block_until_ready()
# print("Timing direct adjoint: quadratic path (vmap)")
# %timeit vmap_direct_solve(dummy_args).block_until_ready()
# print("Timing recursive checkpoint adjoint: quadratic path (vmap)")
# %timeit vmap_recursive_solve(dummy_args).block_until_ready()
The plot above is parsed output from jax.debug.print statements placed in NeuralCDE.__call__ 😅
Very preliminary, and not rigorous!
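For completeness, this is the kind of instrumentation I mean -- a self-contained sketch with a toy ODE rather than the neural CDE (inside NeuralCDE.__call__ the same jax.debug.print line simply goes after the diffeqsolve call):

import diffrax
import jax
import jax.numpy as jnp

def solve_and_report(y0):
    sol = diffrax.diffeqsolve(
        diffrax.ODETerm(lambda t, y, args: -y),
        diffrax.Tsit5(), 0.0, 1.0, None, y0,
        stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
    )
    # Prints from inside jit, which is what produced the numbers behind the plot.
    jax.debug.print(
        "steps={n} accepted={a} rejected={r}",
        n=sol.stats["num_steps"],
        a=sol.stats["num_accepted_steps"],
        r=sol.stats["num_rejected_steps"],
    )
    return sol.ys

jax.jit(solve_and_report)(jnp.array(1.0))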
@johannahaffner, does your example here involve a gradient calculation? If so, then that will explain the discrepancy in the number of steps.
Whilst they'll both take the exact same number of steps in the diffeq solve (as I think @lockwo's example demonstrates), they do backpropagate in slightly different ways, by design, with DirectAdjoint being asymptotically less efficient in the number of vector field evaluations (which is why I don't recommend it).
It does, it runs %timeit make_step which does call grad_loss. Makes sense that this would differ!
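If it's useful, here is a sketch that isolates just that backward-pass difference on a toy ODE (not the neural CDE; the forward solves are identical, only the gradient computation differs):

import diffrax
import equinox as eqx
import jax.numpy as jnp

term = diffrax.ODETerm(lambda t, y, args: args * y)

def make_grad(adjoint):
    @eqx.filter_jit
    @eqx.filter_grad
    def grad_of_solve(args):
        sol = diffrax.diffeqsolve(
            term, diffrax.Tsit5(), 0.0, 1.0, None, jnp.array(1.0), args,
            stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
            adjoint=adjoint,
        )
        return sol.ys[-1]  # scalar output, differentiated w.r.t. `args`
    return grad_of_solve

grad_recursive = make_grad(diffrax.RecursiveCheckpointAdjoint())
grad_direct = make_grad(diffrax.DirectAdjoint())
_ = grad_recursive(jnp.array(-1.0))  # compile first
_ = grad_direct(jnp.array(-1.0))
%timeit grad_recursive(jnp.array(-1.0)).block_until_ready()
%timeit grad_direct(jnp.array(-1.0)).block_until_ready()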