VBT vs brownian path slowdown

Open lockwo opened this issue 1 year ago • 7 comments

While implementing some weak solver schemes, I noticed that using the VBT (VirtualBrownianTree) instead of the UBP (UnsafeBrownianPath) was substantially slower (~5 minutes vs ~10 seconds). Using a UBP is fine for us in this case (since it's a fixed-step solver and we aren't differentiating through the equation), but it is not ideal for the future. Below is a minimal example, but in summary:

  • This is not the real solver; I ripped out almost everything to keep the code small.
  • These numbers may look small and microbenchmark-y, but I have verified them on some larger problems as well. I am using this small problem only for speed and demonstration.
  • VBT is ~10x slower here, and the cost seems dominated by the line u += g1 @ (_dW + chi1). Specifically, if I comment that out, the runtime drops from ~11s to ~3s. With UBP it only goes from about 1.1s to 0.8s (so the gap isn't explained just by eliminating a matmul).
  • The surprising thing is not that VBT is slower (I figured it would come with some overhead), but that the slowdown scales. Specifically, if I decrease dt, the extra cost is not a constant overhead but grows. Maybe this is expected, but even for small problems with large dts this becomes prohibitive (see the original 5 min vs 10 seconds).

All of this is a bit surprising since I just call the diffusion control once. Is there a way of using VBT's or integrating them into new solvers that avoids this slowdown, or am I just making some mistake in my usage of the VBT?

Here is the full code:

import jax
from jax import numpy as jnp
import diffrax
from typing import ClassVar

_NORMAL_ONESIX_QUANTILE = -0.9674215661017014


def calc_threepoint_random(x):
    return jnp.where(
        jnp.abs(x) > -_NORMAL_ONESIX_QUANTILE,
        jnp.where(x < _NORMAL_ONESIX_QUANTILE, -1.0, 1.0),
        0.0,
    )


def calc_twopoint_random(x):
    return jnp.where(x > 0, 1.0, -1.0)


class Solver(diffrax.AbstractSolver):

    term_structure: ClassVar = diffrax.AbstractTerm
    interpolation_cls: ClassVar = diffrax.LocalLinearInterpolation

    def func(self, terms, t0, y0, args):
        return terms.vf(t0, y0, args)

    def init(self, terms, t0, t1, y0, args):
        return None

    def step(self, terms, t0, t1, y0, args, solver_state, made_jump):
        drift = terms.terms[0]
        diffusion = terms.terms[1]
        cont = diffusion.contr(t0, t1)
        dt = t1 - t0
        dW_scaled = cont["dW"] / jnp.sqrt(dt)
        sq3dt = jnp.sqrt(3 * dt)
        _dW = sq3dt * calc_threepoint_random(dW_scaled)
        dZ_scaled = cont["dZ"]
        _dZ = calc_twopoint_random(dZ_scaled)
        xi = jnp.sqrt(dt) * _dZ[0]
        chi1 = (_dW**2 / xi - xi) / 2
        k1 = drift.vf(t0, y0, args)
        g1 = diffusion.vf(t0, y0, args)
        H02 = y0 + k1 * dt + g1 @ _dW
        k2 = drift.vf(t0, H02, args)
        H03 = y0 + k2 * dt + k1 * dt + g1 @ _dW
        k3 = drift.vf(t0, H03, args)
        u = y0 + k1 * dt + k2 * dt + k3 * dt
        u += g1 @ (_dW + chi1)
        dense_info = dict(y0=y0, y1=u)

        return u, None, dense_info, None, diffrax.RESULTS.successful

def drift(t, X, args):
    y1, y2 = X
    dy1 = -273 / 512 * y1
    # True division: `//` is floor division in Python, so -1 // 160 == -1 and -785 // 512 == -2.
    dy2 = -1 / 160 * y1 - (-785 / 512 + jnp.sqrt(2) / 8) * y2
    return jnp.array([dy1, dy2])


def diffusion(t, X, args):
    y1, y2 = X
    g11 = 1 / 4 * y1
    g12 = 1 / 16 * y1
    g21 = (1 - 2 * jnp.sqrt(2)) / 4 * y1
    # True division: with `//` both coefficients floor to 0, making g22 identically zero.
    g22 = 1 / 10 * y1 + 1 / 16 * y2

    return jnp.array([[g11, g12], [g21, g22]])


t0, t1 = 0.0, 3.0
y0 = jnp.array([1.0, 1.0])


def solve_wrapper(dt, num_samples, use_tree):
    keys = jax.random.split(jax.random.key(42), num_samples)
    solver = Solver()
    saveat = diffrax.SaveAt(t1=True)

    def solve(key):
        if not use_tree:
            tree = diffrax.UnsafeBrownianPath(
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            return diffrax.diffeqsolve(
                terms,
                solver,
                t0,
                t1,
                dt0=dt,
                y0=y0,
                saveat=saveat,
                adjoint=diffrax.DirectAdjoint(),
            )
        else:
            tree = diffrax.VirtualBrownianTree(
                t0,
                t1,
                tol=dt / 2,
                shape={
                    "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                    "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                },
                key=key,
            )
            terms = diffrax.MultiTerm(
                diffrax.ODETerm(drift), diffrax.ControlTerm(diffusion, tree)
            )
            return diffrax.diffeqsolve(
                terms, solver, t0, t1, dt0=dt, y0=y0, saveat=saveat
            )
            # , adjoint=diffrax.DirectAdjoint()) this is 2x slower

    return jax.jit(jax.vmap(solve))(keys).ys.squeeze(axis=1)

%%timeit
_ = solve_wrapper(1.0, 20 * 100_000, True).block_until_ready()

yields 9.37 s ± 281 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) for VBT and 1.03 s ± 8.17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) for UBP.

lockwo avatar Aug 16 '24 20:08 lockwo

Maybe this isn't even solver-specific. I just noticed that with Heun instead, at dt = 0.1, VBT takes 27s but UBP takes 3s. So maybe the real question is twofold: is there any way to make VBT faster, and if not, is there any way to program around VBT so we can make adaptive-stepping solvers, or differentiable solvers, assuming we can maintain one of the three pillars of the UBP requirements?

lockwo avatar Aug 16 '24 20:08 lockwo

So VBT is sensitive to the tolerance used: decreasing the tolerances increases the computational cost. This is the reason you're seeing scaling with dt.
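
As a rough illustration of that scaling, here is a back-of-the-envelope sketch. It assumes the tree bisects [t0, t1] until the interval is no wider than tol, and that each solver step evaluates both of its endpoints. The function names are made up for this example, and the counts are a cost model, not a measurement of diffrax's internals.

```python
import math


def tree_depth(t0, t1, tol):
    """Halvings needed before a bisected piece of [t0, t1] is no wider than tol."""
    return max(0, math.ceil(math.log2((t1 - t0) / tol)))


def estimated_loop_iterations(t0, t1, dt):
    """Rough bisection-iteration count for one fixed-step solve with tol = dt / 2."""
    num_steps = round((t1 - t0) / dt)
    evaluations = 2 * num_steps  # assumption: each step evaluates both endpoints
    return evaluations * tree_depth(t0, t1, tol=dt / 2)


# Same interval and tol = dt / 2 as in the issue's reproduction script.
for dt in (1.0, 0.1, 0.01):
    print(dt, tree_depth(0.0, 3.0, dt / 2), estimated_loop_iterations(0.0, 3.0, dt))
```

Under this model the work per solve grows like (number of steps) x log2((t1 - t0) / tol), so shrinking dt (and with it tol) costs more than the extra steps alone.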

Other than that I think there is performance being left on the table with our VBT implementation:

  • on every sample it goes through its entire while loop again, despite the fact that most of those iterations are probably identical to the last sample.
  • across two adjacent steps [t0, t1], [t1, t2], we end up evaluating vbt(t1) - vbt(t0) followed by vbt(t2) - vbt(t1), so we actually evaluate vbt(t1) twice! This doubles the computational work.

These are things that I think will require some careful thought to fix, so they've never made it far enough up my to-do list. (What I really want is an LRU cache on the evaluations of the loop body.) FWIW I did benchmark all of this when I first wrote it, and found that whilst there was a slowdown, it wasn't as dramatic as you're seeing here. It might be that you're in a case where this is particularly pronounced, or something might have sneakily regressed without me noticing...
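
To make the double-evaluation point concrete, here is a toy count in plain Python. The `evaluate` function is a hypothetical stand-in for one expensive tree evaluation (it only counts calls and returns a placeholder value), so this shows the bookkeeping, not how a cache would be wired into diffeqsolve.

```python
calls = 0


def evaluate(t):
    """Stand-in for one expensive tree evaluation; it just counts calls."""
    global calls
    calls += 1
    return 2.0 * t  # placeholder value, not a Brownian sample


def increments_naive(ts):
    # Each step evaluates both of its endpoints, so interior points are hit twice.
    return [evaluate(b) - evaluate(a) for a, b in zip(ts[:-1], ts[1:])]


def increments_reusing(ts):
    # Evaluate every grid point once, then difference neighbouring values.
    values = [evaluate(t) for t in ts]
    return [b - a for a, b in zip(values[:-1], values[1:])]


ts = [0.0, 1.0, 2.0, 3.0]

calls = 0
naive = increments_naive(ts)
naive_calls = calls

calls = 0
reused = increments_reusing(ts)
reused_calls = calls

print(naive_calls, reused_calls)
```

For N steps the naive version makes 2N evaluations and the reusing version N + 1. In a real solve the grid is generally not known up front (adaptive stepping), which is presumably why a cache on previous evaluations is the natural fix rather than precomputing all points.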

(Semi-relatedly we also have this benchmark: https://github.com/patrick-kidger/diffrax/blob/main/benchmarks/brownian_tree_times.py)

Another thing I'd love to see an implementation of some time is the Brownian Interval from this paper, but again that's fairly fiddly in JAX's model of computation. (Not impossible though I think.)

patrick-kidger avatar Aug 17 '24 06:08 patrick-kidger

When I compared the old VBT to the new one, I didn't see a pronounced slowdown, so it doesn't seem like a regression (although I am generally supportive of speed regression tests). But one thing I did notice: if I make the (new) VBT over an array rather than a pytree, e.g.

              # shape=jax.ShapeDtypeStruct((4,), dtype=jnp.float64),
              shape={
                  "dW": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
                  "dZ": jax.ShapeDtypeStruct((2,), dtype=jnp.float64),
              },

and just parse the array and convert to dict in solver

if not isinstance(cont, dict):
    cont = {"dW": cont[:2], "dZ": cont[2:]}

this is >2x faster than if I just have the VBT over a dictionary (from 11s to 4s). Maybe this is expected since the tree is iterated over (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_brownian/tree.py#L323) and tree maps happen sequentially (I assume?). If this is the case, is there a reason not to just flatten? I could just be misreading it.
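
A more general version of the same flatten-then-split trick can be sketched with `jax.flatten_util.ravel_pytree`, which returns a flat array plus a function that rebuilds the original structure. The template dict below mirrors the "dW"/"dZ" layout from the script above; the `flat_control` values are placeholders, not Brownian samples.

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Template with the same structure as the control used in the solver.
template = {"dW": jnp.zeros(2), "dZ": jnp.zeros(2)}

# flat_template has shape (4,); unravel maps a flat array back to the dict.
flat_template, unravel = ravel_pytree(template)

# Placeholder for what a Brownian object defined over a flat array would return.
flat_control = jnp.arange(4.0)

# Rebuild the dict inside the solver instead of slicing by hand.
control = unravel(flat_control)
print(flat_template.shape, control["dW"], control["dZ"])
```

The flat shape could then be handed to the Brownian object in place of the dict of shapes, as in the commented-out line above, with `unravel` replacing the manual slicing in the solver.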

I'm starting to have a lurking suspicion we will end up implementing a brownian path object, then diff check it against VBT to understand it better lol.

I think we can definitely squeeze more performance from the VBT (and we can add those to the list of things to implement), it just feels like the slowdown shouldn't be this much even with a slightly suboptimal VBT.

> Another thing I'd love to see an implementation of some time is the Brownian Interval from this paper, but again that's fairly fiddly in JAX's model of computation. (Not impossible though I think.)

Was there ever an implementation in something like torchsde?

lockwo avatar Aug 17 '24 07:08 lockwo

I think the iteration over tree leaves is indeed a mistake performance-wise. FWIW in principle JAX should be able to parallelize each call, but in practice it seems that it is not doing that...

I think you're right, the better approach would be to have a single loop that acts over a PyTree.
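
For illustration, the two structures being compared look roughly like this. It is a toy sketch: the loop body is an arbitrary placeholder update rather than the tree's bisection step, and the function names are invented for this example.

```python
import jax
import jax.numpy as jnp


def update(v):
    # Placeholder for the per-iteration work done on one leaf.
    return 0.5 * v + 1.0


def per_leaf_loops(tree, n):
    """One separate loop per leaf, mapped over the PyTree."""
    def one_leaf(x):
        return jax.lax.fori_loop(0, n, lambda i, v: update(v), x)

    return jax.tree_util.tree_map(one_leaf, tree)


def single_loop(tree, n):
    """A single loop whose carried state is the whole PyTree."""
    def body(i, state):
        return jax.tree_util.tree_map(update, state)

    return jax.lax.fori_loop(0, n, body, tree)


tree = {"dW": jnp.ones(2), "dZ": jnp.zeros(2)}
a = per_leaf_loops(tree, 5)
b = single_loop(tree, 5)
```

Both return the same values; the difference is only in how many loops XLA is asked to compile and schedule. Whether the single-loop form is actually faster for the VBT would need benchmarking.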

I'll tag @andyElking on this issue too.

> Was there ever an implementation in something like torchsde?

Yup. torchsde is where I wrote the original canonical Brownian Interval implementation.

(In practice the implementation is really rather complicated, and has a couple of footguns of its own!)

patrick-kidger avatar Aug 17 '24 08:08 patrick-kidger

I completely agree. In fact I've been eyeing that split by PyTree for a while now and am intending to refactor it soonish. In addition I am intending to add an LRU cache to _evaluate, but that can be a separate edit since it also requires changes to diffeqsolve.

andyElking avatar Aug 17 '24 10:08 andyElking

> FWIW in principle JAX should be able to parallelize each call, but in practice it seems that it is not doing that...

Yea, it's weird since the path is (probably) parallelizing it over the same pytree (https://github.com/patrick-kidger/diffrax/blob/main/diffrax/_brownian/path.py#L126). Not sure why the inner map of VBT isn't.

> (In practice the implementation is really rather complicated, and has a couple of footguns of its own!)

Is the reason for implementing the VBT rather than the BI just fewer footguns? Or is it easier in JAX?

lockwo avatar Aug 21 '24 06:08 lockwo

Primarily just that the VBT is easier in JAX. The BI algorithm involves dynamically creating a tree, so to do that in JAX you'd have to preallocate a buffer and then use that to store pointers (indices) into itself.
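
A minimal sketch of what "preallocate a buffer and store indices into itself" could look like in JAX is below. This is only the data-structure idea, not the Brownian Interval algorithm: the `Tree` container, `make_tree`, and `split` are hypothetical names for this example, no Brownian values are stored, and there is no check for running out of capacity.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Tree(NamedTuple):
    lo: jax.Array     # left endpoint of each node's interval
    hi: jax.Array     # right endpoint of each node's interval
    left: jax.Array   # index of the left child, -1 for a leaf
    right: jax.Array  # index of the right child, -1 for a leaf
    size: jax.Array   # number of slots currently in use


def make_tree(t0, t1, capacity):
    """Fixed-size buffers with the root interval [t0, t1] in slot 0."""
    lo = jnp.zeros(capacity).at[0].set(t0)
    hi = jnp.zeros(capacity).at[0].set(t1)
    child = jnp.full(capacity, -1, dtype=jnp.int32)
    return Tree(lo, hi, child, child, jnp.array(1, dtype=jnp.int32))


@jax.jit
def split(tree, node, t):
    """Split leaf `node` at time t, writing its children into the next free slots."""
    i, j = tree.size, tree.size + 1
    lo = tree.lo.at[i].set(tree.lo[node]).at[j].set(t)
    hi = tree.hi.at[i].set(t).at[j].set(tree.hi[node])
    left = tree.left.at[node].set(i)    # "pointer" is just an index into the buffer
    right = tree.right.at[node].set(j)
    return Tree(lo, hi, left, right, tree.size + 2)


tree = make_tree(0.0, 3.0, capacity=8)
tree = split(tree, 0, 1.0)  # root [0, 3] -> children [0, 1] and [1, 3]
```

Because every array has a fixed shape, `split` can be jitted and called with a traced node index, at the price of choosing the capacity up front.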

patrick-kidger avatar Aug 21 '24 06:08 patrick-kidger