VBT vs. Brownian path slowdown
While implementing some weak solver schemes, I noticed that using the VBT (`VirtualBrownianTree`) instead of the UBP (`UnsafeBrownianPath`) was substantially slower (roughly ~5 mins vs. ~10 seconds). Using a UBP is fine for us in this case (it's a fixed-step solver, and we aren't differentiating through the equation), but that won't always be true in the future. Below is an MVC, but in summary:
- This is not the real solver; I ripped out almost everything to keep the code small.
- These numbers look a bit small/microbenchmark-y, but I have verified them on some larger problems as well; I'm just using this small problem for speed and demonstration.
- VBT is 10x slower here, and the slowdown seems dominated by the line `u += g1 @ (_dW + chi1)`. Specifically, if I comment that line out, the runtime drops from ~11s to ~3s, whereas with the UBP it only drops from ~1.1s to ~0.8s (so removing one matmul doesn't explain the whole gap).
- The surprising thing is not that VBT is slower (I figured it would come with some overhead), but that the overhead seems to scale: if I decrease dt, it isn't a constant overhead but grows. Maybe this is expected, but even for small problems with large dts this becomes prohibitive (see the original 5 min vs. 10 seconds).
All of this is a bit surprising, since I only call the diffusion control once per step. Is there a way of using VBTs, or integrating them into new solvers, that avoids this slowdown, or am I just making a mistake in how I use the VBT?
Here is the full code:
```python
import jax
from jax import numpy as jnp
import diffrax
from typing import ClassVar

# The float64 dtypes below require 64-bit mode.
jax.config.update("jax_enable_x64", True)

_NORMAL_ONESIX_QUANTILE = -0.9674215661017014


def calc_threepoint_random(x):
    # Map N(0, 1) draws to {-1, 0, +1} with probabilities 1/6, 2/3, 1/6.
    return jnp.where(
        jnp.abs(x) > -_NORMAL_ONESIX_QUANTILE,
        jnp.where(x < _NORMAL_ONESIX_QUANTILE, -1.0, 1.0),
        0.0,
    )


def calc_twopoint_random(x):
    # Map N(0, 1) draws to {-1, +1} with equal probability.
    return jnp.where(x > 0, 1.0, -1.0)


class Solver(diffrax.AbstractSolver):
    term_structure: ClassVar = diffrax.AbstractTerm
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation

    def func(self, terms, t0, y0, args):
        return terms.vf(t0, y0, args)

    def init(self, terms, t0, t1, y0, args):
        return None

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        drift = terms.terms[0]
        diffusion = terms.terms[1]
        # The only query of the Brownian motion in this step.
        cont = diffusion.contr(t0, t1)
        dt = t1 - t0
        # Three-point weak approximation of dW: values in {0, +-sqrt(3 dt)}.
        dW_scaled = cont["dW"] / jnp.sqrt(dt)
        sq3dt = jnp.sqrt(3 * dt)
        _dW = sq3dt * calc_threepoint_random(dW_scaled)
        # Two-point random variable used for the chi1 correction term.
        dZ_scaled = cont["dZ"]
        _dZ = calc_twopoint_random(dZ_scaled)
        xi = jnp.sqrt(dt) * _dZ[0]
        chi1 = (_dW**2 / xi - xi) / 2

        k1 = drift.vf(t0, y0, args)
        g1 = diffusion.vf(t0, y0, args)
        H02 = y0 + k1 * dt + g1 @ _dW
        k2 = drift.vf(t0, H02, args)
        H03 = y0 + k2 * dt + k1 * dt + g1 @ _dW
        k3 = drift.vf(t0, H03, args)
        u = y0 + k1 * dt + k2 * dt + k3 * dt
        u += g1 @ (_dW + chi1)  # commenting this out removes most of the VBT cost
        dense_info = dict(y0=y0, y1=u)
        return u, None, dense_info, None, diffrax.RESULTS.successful


def drift(t, X, args):
    y1, y2 = X
    dy1 = -273 / 512 * y1
    dy2 = -1 / 160 * y1 - (-785 / 512 + jnp.sqrt(2) / 8) * y2
    return jnp.array([dy1, dy2])


def diffusion(t, X, args):
    y1, y2 = X
    g11 = 1 / 4 * y1
    g12 = 1 / 16 * y1
    g21 = (1 - 2 * jnp.sqrt(2)) / 4 * y1
    g22 = 1 / 10 * y1 + 1 / 16 * y2
    return jnp.array([[g11, g12], [g21, g22]])


t0, t1 = 0.0, 3.0
y0 = jnp.array([1.0, 1.0])


def solve_wrapper(dt, num_samples, use_tree):
    keys = jax.random.split(jax.random.key(42), num_samples)
    solver = Solver()
    saveat = diffrax.SaveAt(t1=True)

    def solve(key):
        if not use_tree:
            # Fixed steps, no differentiation: UnsafeBrownianPath is allowed here.
            tree = diffrax.UnsafeBrownianPath(
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            return diffrax.diffeqsolve(
                terms,
                solver,
                t0,
                t1,
                dt0=dt,
                y0=y0,
                saveat=saveat,
                adjoint=diffrax.DirectAdjoint(),
            )
        else:
            # The tree tolerance is tied to the step size.
            tree = diffrax.VirtualBrownianTree(
                t0,
                t1,
                tol=dt / 2,
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            # Adding adjoint=diffrax.DirectAdjoint() to this call is 2x slower.
            return diffrax.diffeqsolve(
                terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat
            )

    return jax.jit(jax.vmap(solve))(keys).ys.squeeze(axis=1)
```
```python
%%timeit
_ = solve_wrapper(1.0, 20 * 100_000, True).block_until_ready()
```
This yields 9.37 s ± 281 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) for VBT and 1.03 s ± 8.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) for UBP.
Maybe this isn't even solver-specific: I just noticed that using `Heun` instead, with `dt = 0.1`, VBT takes 27s but UBP takes 3s. So maybe the real question is twofold: is there any way to make VBT faster, and if not, is there any way to program around VBT so we can make adaptive-step solvers or differentiable solvers, assuming we can maintain one of the three pillars of the UBP requirements?
So VBT is sensitive to the tolerance used: decreasing the tolerance increases the computational cost. Since your code sets `tol=dt / 2`, this is the reason you're seeing the cost scale with dt.
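As a rough back-of-the-envelope model (an assumption, not a description of diffrax's internals): if each query bisects `[t0, t1]` until the bracketing interval is shorter than `tol`, one query costs about `ceil(log2((t1 - t0) / tol))` loop iterations, and each step makes two queries (one per endpoint). With `tol = dt / 2`, the total work then grows like `(T / dt) * log2(2T / dt)`, i.e. faster than the number of steps. `vbt_depth` and `relative_cost` are hypothetical helpers that only count loop iterations:

```python
import math


def vbt_depth(t0: float, t1: float, tol: float) -> int:
    # Assumed model: bisect until the bracketing interval is shorter than tol.
    return max(0, math.ceil(math.log2((t1 - t0) / tol)))


def relative_cost(t0: float, t1: float, dt: float, queries_per_step: int = 2) -> int:
    # Number of steps x queries per step x bisection iterations per query,
    # with the tolerance tied to the step size as in the benchmark (tol = dt / 2).
    n_steps = round((t1 - t0) / dt)
    return n_steps * queries_per_step * vbt_depth(t0, t1, tol=dt / 2)


for dt in (1.0, 0.1, 0.01):
    print(dt, relative_cost(0.0, 3.0, dt))
```

Under this model, going from `dt = 1.0` to `dt = 0.1` multiplies the step count by 10 but the iteration count by 20, which is consistent with the "not just a constant overhead" observation; it says nothing about absolute timings.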
Other than that, I think there is performance being left on the table in our VBT implementation:
- On every evaluation it goes through its entire while loop again, even though most of those iterations are probably identical to the ones for the previous evaluation.
- Across two adjacent steps `[t0, t1]`, `[t1, t2]`, we end up evaluating `vbt(t1) - vbt(t0)` followed by `vbt(t2) - vbt(t1)`, so we actually evaluate `vbt(t1)` twice! This doubles the computational work.
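The double evaluation of `vbt(t1)` is the generic cost of building each increment from two independent point queries. A minimal sketch of the alternative, carrying the right endpoint forward to the next step, is below. `CountingPath` is a hypothetical stand-in for a VBT (any deterministic map t -> W(t)) and nothing here is diffrax API; in a diffrax solver this would presumably require threading the previous value through the solver state or changing how the control is queried, which the sketch does not attempt.

```python
from typing import Callable, List


class CountingPath:
    """Hypothetical stand-in for a VBT: a deterministic t -> W(t) map that counts calls."""

    def __init__(self, fn: Callable[[float], float]):
        self.fn = fn
        self.calls = 0

    def __call__(self, t: float) -> float:
        self.calls += 1
        return self.fn(t)


def increments_naive(path, ts: List[float]) -> List[float]:
    # Every interior grid point is evaluated twice: once as a right endpoint, once as a left one.
    return [path(b) - path(a) for a, b in zip(ts[:-1], ts[1:])]


def increments_reuse(path, ts: List[float]) -> List[float]:
    # Carry the right endpoint forward, so each grid point is evaluated exactly once.
    out = []
    w_prev = path(ts[0])
    for t in ts[1:]:
        w_next = path(t)
        out.append(w_next - w_prev)
        w_prev = w_next
    return out


ts = [0.0, 0.5, 1.0, 1.5, 2.0]
naive, reuse = CountingPath(lambda t: t * t), CountingPath(lambda t: t * t)
print(increments_naive(naive, ts), naive.calls)  # 4 steps, 8 queries
print(increments_reuse(reuse, ts), reuse.calls)  # 4 steps, 5 queries
```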
These are things that I think will require some careful thought to fix, so they've never made it far enough up my to-do list. (What I really want is an LRU cache on the evaluations of the loop body.) FWIW, I did benchmark all of this when I first wrote it, and found that whilst there was a slowdown, it wasn't as dramatic as what you're seeing here. It might be that you're in a case where this is particularly pronounced, or something might have sneakily regressed without me noticing...
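To illustrate what an LRU cache on the loop body could buy, here is a toy pure-Python dyadic Brownian tree (a simplified sketch, not diffrax's VBT): each node's value is a Brownian-bridge sample seeded only by its `(level, index)`, so the tree is deterministic, and `functools.lru_cache` memoises node values. Adjacent steps and repeated endpoints then share cached nodes instead of recomputing them. `T0`, `T1`, `SEED`, `node_value` and `dyadic_query` are hypothetical names, and the final linear interpolation between the two finest-level nodes is a simplification.

```python
import functools
import math
import random

T0, T1, SEED = 0.0, 3.0, 42  # hypothetical interval and seed


@functools.lru_cache(maxsize=1024)
def node_value(level: int, index: int) -> float:
    """W at the dyadic time T0 + index * (T1 - T0) / 2**level, with W(T0) = 0."""
    if index == 0:
        return 0.0
    if level == 0:  # index == 1: the right endpoint, W(T1) ~ N(0, T1 - T0)
        return random.Random(f"{SEED}-0-1").gauss(0.0, math.sqrt(T1 - T0))
    if index % 2 == 0:  # this time is already a node of the coarser level
        return node_value(level - 1, index // 2)
    # Odd index: Brownian-bridge midpoint of the parent interval,
    # mean (W(left) + W(right)) / 2 and variance (parent width) / 4.
    left = node_value(level - 1, index // 2)
    right = node_value(level - 1, index // 2 + 1)
    std = math.sqrt((T1 - T0) / 2 ** (level + 1))
    rng = random.Random(f"{SEED}-{level}-{index}")  # deterministic per node
    return 0.5 * (left + right) + rng.gauss(0.0, std)


def dyadic_query(t: float, tol: float) -> float:
    """Approximate W(t) using the finest level whose node spacing is below tol."""
    depth = max(0, math.ceil(math.log2((T1 - T0) / tol)))
    width = (T1 - T0) / 2**depth
    k = min(int((t - T0) // width), 2**depth - 1)  # left node index, clamped at T1
    frac = (t - (T0 + k * width)) / width
    w_left, w_right = node_value(depth, k), node_value(depth, k + 1)
    return w_left + frac * (w_right - w_left)


ts = [0.5 * i for i in range(7)]  # steps of 0.5 on [0, 3]
dws = [dyadic_query(b, 0.25) - dyadic_query(a, 0.25) for a, b in zip(ts[:-1], ts[1:])]
print(node_value.cache_info())  # non-zero hits: shared nodes came from the cache
```

In JAX the equivalent would have to be a fixed-size cache living in the solver's carried state, which is presumably part of what makes it fiddly.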
(Semi-relatedly we also have this benchmark: https://github.com/patrick-kidger/diffrax/blob/main/benchmarks/brownian_tree_times.py)
Another thing I'd love to see an implementation of some time is the Brownian Interval from this paper, but again that's fairly fiddly in JAX's model of computation. (Not impossible though I think.)
When I compare the old VBT to the new one, I don't see a pronounced slowdown, so it doesn't seem like a regression (although I am generally supportive of speed regression tests). But one thing I did notice: if I make the (new) VBT over an array rather than a PyTree, e.g.
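```python
# Array-valued VBT: 4 channels, the first 2 for dW and the last 2 for dZ.
shape=jax.ShapeDtypeStruct((4,), dtype=jnp.float64),
# ...instead of the PyTree-valued version:
# shape={
#     "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
#     "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
# },
```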
and then just split the array back into a dict inside the solver's `step`:
```python
if not isinstance(cont, dict):
    cont = {"dW": cont[:2], "dZ": cont[2:]}
```
This is >2x faster than having the VBT over a dictionary (from 11s down to 4s). Maybe this is expected, since the tree is iterated over (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_brownian/tree.py#L323) and the tree maps happen sequentially (I assume?). If that's the case, is there a reason not to just flatten? I could just be misreading it.
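Rather than hard-coding the slices `cont[:2]` and `cont[2:]`, one option is JAX's `jax.flatten_util.ravel_pytree`, which flattens a PyTree into one 1D array and returns the inverse function. A minimal sketch, assuming the template is built with the same dtype as the VBT's array shape (`ravel_pytree` concatenates leaves in JAX's flattening order, which for dicts is sorted by key, so `"dW"` comes before `"dZ"`); `template` and `split_control` are hypothetical names:

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Only shapes and dtypes of the template matter; the values are ignored.
template = {"dW": jnp.zeros(2), "dZ": jnp.zeros(2)}
flat_template, unravel = ravel_pytree(template)  # flat_template.shape == (4,)


def split_control(cont):
    """Turn the flat increment of an array-valued VBT back into the dict the solver uses."""
    if isinstance(cont, dict):
        return cont
    return unravel(cont)
```

The VBT would then be built with `shape=jax.ShapeDtypeStruct(flat_template.shape, flat_template.dtype)`, and `step` would start with `cont = split_control(diffusion.contr(t0, t1))`. This keeps the solver code PyTree-shaped while the Brownian motion itself stays a single array.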
I'm starting to have a lurking suspicion we will end up implementing a Brownian path object, then diff-checking it against the VBT to understand it better lol.
I think we can definitely squeeze more performance from the VBT (and we can add those to the list of things to implement); it just feels like the slowdown shouldn't be this large even with a slightly suboptimal VBT.
> Another thing I'd love to see an implementation of some time is the Brownian Interval from this paper, but again that's fairly fiddly in JAX's model of computation. (Not impossible though I think.)
Was there ever an implementation in something like torchsde?
I think the iteration over tree leaves is indeed a mistake performance-wise. FWIW in principle JAX should be able to parallelize each call, but in practice it seems that it is not doing that...
I think you're right: the better approach would be to have a single loop that acts over the whole PyTree.
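A minimal JAX sketch of "one loop over the whole PyTree", under simplifying assumptions (no Lévy area, linear interpolation at the final level, `W(t0) = 0` and `W(t1)` given): the key observation is that which half of the interval to descend into depends only on the query time `t`, so the interval bookkeeping is shared and only the Brownian-bridge sample is mapped over the leaves inside the loop body. This is not diffrax's implementation; `_bridge_midpoint`, `single_loop_evaluate` and `depth` (roughly `ceil(log2((t1 - t0) / tol))`) are hypothetical.

```python
import jax
import jax.numpy as jnp


def _bridge_midpoint(w_s, w_u, s, u, key):
    # Brownian-bridge midpoint for every leaf: mean (W(s) + W(u)) / 2, variance (u - s) / 4.
    leaves, treedef = jax.tree_util.tree_flatten(w_s)
    leaf_keys = jax.tree_util.tree_unflatten(
        treedef, list(jax.random.split(key, len(leaves)))
    )
    std = jnp.sqrt((u - s) / 4)
    return jax.tree_util.tree_map(
        lambda a, b, k: 0.5 * (a + b) + std * jax.random.normal(k, a.shape, a.dtype),
        w_s,
        w_u,
        leaf_keys,
    )


def single_loop_evaluate(t, key, t0, t1, w_t1, depth):
    """Approximate W(t) for all leaves of w_t1 with ONE bisection loop."""
    dtype = jnp.result_type(*jax.tree_util.tree_leaves(w_t1))
    w_t0 = jax.tree_util.tree_map(jnp.zeros_like, w_t1)

    def body(_, carry):
        s, u, w_s, w_u, node_key = carry
        mid = 0.5 * (s + u)
        k_mid, k_child = jax.random.split(node_key)
        w_mid = _bridge_midpoint(w_s, w_u, s, u, k_mid)
        go_left = t < mid  # depends only on t, so it is shared by every leaf

        def pick(x, y):
            return jax.tree_util.tree_map(lambda a, b: jnp.where(go_left, a, b), x, y)

        return (
            jnp.where(go_left, s, mid),
            jnp.where(go_left, mid, u),
            pick(w_s, w_mid),
            pick(w_mid, w_u),
            # The child's key depends only on the path taken, so queries are reproducible.
            jax.random.fold_in(k_child, go_left.astype(jnp.uint32)),
        )

    init = (jnp.asarray(t0, dtype), jnp.asarray(t1, dtype), w_t0, w_t1, key)
    s, u, w_s, w_u, _ = jax.lax.fori_loop(0, depth, body, init)
    frac = (t - s) / (u - s)
    return jax.tree_util.tree_map(lambda a, b: a + frac * (b - a), w_s, w_u)


w1 = {"dW": jnp.array([0.3, -1.2]), "dZ": jnp.array([0.5, 0.1])}
print(single_loop_evaluate(1.1, jax.random.key(0), 0.0, 3.0, w1, depth=4))
```

Compared with a separate loop per leaf, the loop counter, interval endpoints and keys are computed once per iteration for the whole PyTree.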
I'll tag @andyElking on this issue too.
> Was there ever an implementation in something like torchsde?
Yup. torchsde is where I wrote the original canonical Brownian Interval implementation.
(In practice the implementation is really rather complicated, and has a couple of footguns of its own!)
I completely agree. In fact, I've been eyeing that split by PyTree for a while now and intend to refactor it soonish. In addition, I intend to add an LRU cache to `_evaluate`, but that can be a separate change, since it also requires changes to `diffeqsolve`.
> FWIW in principle JAX should be able to parallelize each call, but in practice it seems that it is not doing that...

Yeah, it's weird, since the path (probably) parallelizes over the same PyTree (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_brownian/path.py#L126). Not sure why the VBT's inner map doesn't.
> (In practice the implementation is really rather complicated, and has a couple of footguns of its own!)
Is the reason for implementing the VBT rather than the BI just fewer footguns? Or is it easier in JAX?
Primarily just that the VBT is easier in JAX. The BI algorithm involves dynamically creating a tree, so to do that in JAX you'd have to preallocate a buffer and then use that to store pointers (indices) into itself.
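To make "preallocate a buffer and store pointers as indices into itself" concrete, here is a minimal pure-Python sketch of just that data structure: a binary tree of intervals that grows as new split points arrive, stored in fixed-size arrays with child links as integer indices. It is not the Brownian Interval algorithm (no sampling, no caching of values), and `IntervalBuffer`, `leaf_containing`, `split_at` and `leaves` are hypothetical names. In JAX, the lists would become fixed-shape arrays updated with `.at[...].set(...)` and the descent a `jax.lax.while_loop`; the awkward part is that `capacity` has to be chosen up front.

```python
NULL = -1  # "null pointer": this node has no children


class IntervalBuffer:
    """Binary tree of intervals in preallocated arrays; children are integer indices."""

    def __init__(self, t0: float, t1: float, capacity: int):
        self.start = [0.0] * capacity
        self.end = [0.0] * capacity
        self.left = [NULL] * capacity
        self.right = [NULL] * capacity
        self.start[0], self.end[0] = t0, t1  # node 0 is the root
        self.n_used = 1

    def leaf_containing(self, t: float) -> int:
        node = 0
        while self.left[node] != NULL:  # follow index "pointers" down the tree
            child = self.left[node]
            node = child if t <= self.end[child] else self.right[node]
        return node

    def split_at(self, t: float) -> None:
        """Make t a boundary by splitting the leaf that contains it."""
        node = self.leaf_containing(t)
        if t in (self.start[node], self.end[node]):
            return  # already a boundary
        if self.n_used + 2 > len(self.start):
            raise MemoryError("buffer full: capacity must be chosen up front")
        lo, hi = self.n_used, self.n_used + 1  # "allocate" two new nodes
        self.n_used += 2
        self.start[lo], self.end[lo] = self.start[node], t
        self.start[hi], self.end[hi] = t, self.end[node]
        self.left[node], self.right[node] = lo, hi

    def leaves(self):
        out, stack = [], [0]
        while stack:
            node = stack.pop()
            if self.left[node] == NULL:
                out.append((self.start[node], self.end[node]))
            else:
                stack.extend([self.right[node], self.left[node]])  # left child popped first
        return out


buf = IntervalBuffer(0.0, 3.0, capacity=16)
for t in (1.0, 2.0, 0.5):
    buf.split_at(t)
print(buf.leaves())
```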