
Tridiagonal solve significantly slower than JAX's built-in when unbatched and for small batch sizes

Open • jpbrodrick89 opened this issue 9 months ago • 6 comments

The current implementation of tridiagonal solve in lineax uses a manually implemented version of the Thomas algorithm with jax.lax.scan and a hard-coded unroll of 32. There is now a jax.lax.linalg.tridiagonal_solve, which directly calls the respective cuSPARSE and LAPACK routines on GPU/CPU, and JVP and transpose rules are now supported on all devices. However, despite the interface appearing batch-friendly, its timing scales linearly with batch size, whereas one would hope for sub-linear scaling on GPU. The latter is indeed achieved by the lineax implementation, but the default unroll provides a sub-optimal starting point.
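For reference, here is a minimal sketch of a Thomas-algorithm solve written with jax.lax.scan and a configurable unroll. It is not lineax's actual source: the function name thomas_solve and the (lower, diag, upper, rhs) argument layout are assumptions for illustration, and there is no pivoting.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames="unroll")
def thomas_solve(lower, diag, upper, rhs, unroll=1):
    """Solve a tridiagonal system A x = rhs with the Thomas algorithm (no pivoting).

    All inputs have shape (n,); lower[i] = A[i, i-1] and upper[i] = A[i, i+1],
    so lower[0] and upper[-1] are not used (any finite value works).
    """

    def forward_sweep(carry, row):
        c_prev, d_prev = carry
        a, b, c, d = row
        denom = b - a * c_prev
        c_new = c / denom
        d_new = (d - a * d_prev) / denom
        return (c_new, d_new), (c_new, d_new)

    zero = jnp.zeros((), rhs.dtype)
    # A zero initial carry makes the first step reduce to c/b and d/b.
    _, (c_star, d_star) = jax.lax.scan(
        forward_sweep, (zero, zero), (lower, diag, upper, rhs), unroll=unroll
    )

    def back_substitute(x_next, row):
        c, d = row
        x = d - c * x_next
        return x, x

    # Walk from the last row to the first; the zero carry handles row n-1.
    _, x = jax.lax.scan(
        back_substitute, zero, (c_star, d_star), reverse=True, unroll=unroll
    )
    return x
```

The unroll argument only controls how many scan iterations are placed inside each iteration of the compiled loop; it does not change the numerical result.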

I have run tests for a simple tridiagonal solve of various array sizes (200–204,800) and its VJP, as well as various batch sizes for a fixed array size of 200. My main conclusions are:

  • Current (unroll=32) lineax implementation of forward solve is about 5x slower than jax's respective implementation on the same device.
  • Current (unroll=32) lineax implementation becomes more efficient than jax for forward solves when batch size is greater than ~20 (CPU) and ~5 (GPU).
  • Current (unroll=32) lineax implementation of VJP is about 10x slower than jax on CPU and more than 200x slower on GPU.
  • Fully rolled (unroll=1) lineax implementation of forward solve is comparable to/slightly faster than jax LAPACK implementation on CPU.
  • Fully rolled (unroll=1) lineax implementation of forward solve is faster than jax LAPACK implementation for batch sizes of 5 or more on CPU.
  • Fully rolled (unroll=1) lineax implementation of VJP is still up to 5x slower than jax.
  • Fully unrolled (unroll=200 initially) lineax is similar to jax on GPU (compilation time a lot longer, of course).
  • unroll=200 is more efficient than jax for all batch sizes on GPU.
  • Both rolled and fully unrolled implementations of VJP are about 100x slower than jax on GPU.

I am quite confident my forward solve timings are fair and representative, but there is a chance my VJP timings might be off.

Options:

  1. Expose setting of unroll to the user, with a default value of 1 on CPU and 32 on GPU (maybe fully unrolled if possible, but users might be surprised by long compilation times), and look into why the VJP is running slower.
  2. Just use jax.lax.linalg.tridiagonal_solve and accept slow batch timings.
  3. A combination of 1 and 2, where we use 1 for forward solves and 2 for the VJP (and probably the JVP too).

Happy to help make the PR once we decide on solution. Thanks!

[Three benchmark graphs (images not reproduced)]

Tests were run on a single Mac M2 CPU core and an NVIDIA A100 GPU on the NERSC Perlmutter supercomputer. Specifically, the problem solves the heat equation for a tanh temperature profile in a plasma with a thermal conductivity scaling as T**2.5 and a relatively small timestep. The VJP is taken with respect to temperature, with a Gaussian cotangent vector.
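The benchmark code itself is not in the thread, so the following is only a guess at how such a system could arise: a semi-implicit (linearised backward-Euler) step of dT/dt = d/dx(κ dT/dx) with κ = T**2.5 frozen at the current temperature. The grid, profile parameters, timestep, arithmetic face averaging and zero-flux boundaries are all assumptions, and a dense jnp.linalg.solve stands in for the tridiagonal solver.

```python
import jax
import jax.numpy as jnp


def heat_step_diagonals(temp, dt, dx):
    """Diagonals of (I - dt * D(T)) for one semi-implicit step.

    Assumptions: kappa = T**2.5 evaluated at the current temperature,
    face values by arithmetic averaging, zero-flux boundaries.
    """
    kappa = temp ** 2.5
    kappa_face = 0.5 * (kappa[1:] + kappa[:-1])  # shape (n - 1,)
    r = dt / dx**2 * kappa_face
    zero = jnp.zeros((1,), temp.dtype)
    lower = jnp.concatenate([zero, -r])  # lower[0] unused
    upper = jnp.concatenate([-r, zero])  # upper[-1] unused
    diag = 1.0 - lower - upper           # 1 + sum of off-diagonal magnitudes
    return lower, diag, upper


def heat_step(temp, dt, dx):
    lower, diag, upper = heat_step_diagonals(temp, dt, dx)
    a = jnp.diag(diag) + jnp.diag(lower[1:], -1) + jnp.diag(upper[:-1], 1)
    return jnp.linalg.solve(a, temp)  # dense stand-in for a tridiagonal solver


n = 200
x = jnp.linspace(-1.0, 1.0, n)
dx = x[1] - x[0]
temp = 1.0 + 0.5 * jnp.tanh(x / 0.1)  # positive tanh profile, so T**2.5 is real
dt = 1e-5                             # hypothetical "small" timestep

lower, diag, upper = heat_step_diagonals(temp, dt, dx)
# Strict diagonal dominance: the regime where Thomas without pivoting is safe.
assert bool(jnp.all(jnp.abs(diag) > jnp.abs(lower) + jnp.abs(upper)))

cotangent = jnp.exp(-((x / 0.2) ** 2))  # Gaussian cotangent
new_temp, vjp_fn = jax.vjp(lambda t: heat_step(t, dt, dx), temp)
(temp_bar,) = vjp_fn(cotangent)
```

With zero-flux boundaries every column of this matrix sums to one, so the step conserves the sum of T, which makes a convenient sanity check for any solver plugged in here.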

I'm happy to share the timing scripts, but they require customisation of unroll on my lineax fork and are also somewhat verbose, as I wanted to ensure the matrix was well-posed and generalised well to different array sizes. Nevertheless, I confirm that I:

  • [x] Passed the diagonals as jax arrays across the jit boundary (operator created inside the jitted function)
  • [x] Used throw=False for fair comparison
  • [x] Used block_until_ready for representative timing.
  • [x] Called each function once with each set of input shapes before timing to eliminate compilation time.
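A minimal timing helper along the lines of this checklist: a warm-up call to exclude compilation, block_until_ready on every timed call, and array inputs crossing the jit boundary with the matrix assembled inside. The names benchmark and dense_tridiagonal_solve are hypothetical, a dense jnp.linalg.solve stands in for whichever tridiagonal solver is being measured, and the lineax-specific throw=False setting is not shown.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, repeats=20):
    """Mean wall-clock seconds per call of fn(*args), excluding compilation."""
    # Warm-up call: compiles fn for these input shapes/dtypes.
    jax.block_until_ready(fn(*args))
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so wait for the result before stopping the clock.
        jax.block_until_ready(fn(*args))
    return (time.perf_counter() - start) / repeats


@jax.jit
def dense_tridiagonal_solve(lower, diag, upper, rhs):
    # Diagonals cross the jit boundary as arrays; the matrix is built inside.
    a = jnp.diag(diag) + jnp.diag(lower[1:], -1) + jnp.diag(upper[:-1], 1)
    return jnp.linalg.solve(a, rhs)


n = 200
lower = jnp.full(n, -1.0)
diag = jnp.full(n, 4.0)
upper = jnp.full(n, -1.0)
rhs = jnp.ones(n)
seconds = benchmark(dense_tridiagonal_solve, lower, diag, upper, rhs)
print(f"{seconds * 1e6:.1f} us per solve")
```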

jpbrodrick89, Apr 14 '25 16:04

Thanks for opening the issue! The thoroughness here is awesome -- in particular for listing all the common gotchas on the benchmarks.

I'm surprised by this combination of outcomes:

> Current (unroll=32) lineax implementation of forward solve is about 5x slower than jax's respective implementation on same device. Fully rolled (unroll=1) lineax implementation of forward solve is comparable/slightly faster than jax lapack implementation on CPU.

(and the corresponding entries in the graphs) which indicates that our unrolling actually hinders performance. Which version of JAX/jaxlib is this on?

In terms of what we can do here, it's worth knowing that this section of Lineax is only ever called under jax.jit or jax.vmap. In particular, it is never differentiated directly (to be precise, it never sees a JVP trace), as this is handled at an outer level by lineax.linear_solve, which automatically builds autodifferentiation and transposition out of just the forward linear solve. (I'm not sure why JAX doesn't also do the same thing, as it's totally possible.)
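For comparison, JAX exposes this pattern for user-supplied solvers as jax.lax.custom_linear_solve: given a matvec, a forward solve and (for reverse mode) a transposed solve, it derives the JVP and transpose rules without differentiating through the solver's internals. The sketch below wraps a scan-based Thomas solve this way; the helper names thomas_solve, tridiagonal_matvec and tridiagonal_linear_solve are hypothetical, and this is not how lineax or jax.lax.linalg.tridiagonal_solve is implemented.

```python
import functools

import jax
import jax.numpy as jnp


def thomas_solve(lower, diag, upper, rhs):
    """Thomas algorithm without pivoting; lower[0] and upper[-1] are not used."""

    def forward_sweep(carry, row):
        c_prev, d_prev = carry
        a, b, c, d = row
        denom = b - a * c_prev
        out = (c / denom, (d - a * d_prev) / denom)
        return out, out

    zero = jnp.zeros((), rhs.dtype)
    _, (c_star, d_star) = jax.lax.scan(forward_sweep, (zero, zero), (lower, diag, upper, rhs))

    def back_substitute(x_next, row):
        c, d = row
        x = d - c * x_next
        return x, x

    _, x = jax.lax.scan(back_substitute, zero, (c_star, d_star), reverse=True)
    return x


def tridiagonal_matvec(lower, diag, upper, v):
    """Compute A @ v, where lower[i] = A[i, i-1] and upper[i] = A[i, i+1]."""
    v_prev = jnp.pad(v[:-1], (1, 0))  # v shifted down by one, zero at the top
    v_next = jnp.pad(v[1:], (0, 1))   # v shifted up by one, zero at the bottom
    return lower * v_prev + diag * v + upper * v_next


def tridiagonal_linear_solve(lower, diag, upper, rhs):
    # Diagonals of A^T: the old upper diagonal becomes the new lower one and vice versa.
    lower_t = jnp.pad(upper[:-1], (1, 0))
    upper_t = jnp.pad(lower[1:], (0, 1))
    return jax.lax.custom_linear_solve(
        functools.partial(tridiagonal_matvec, lower, diag, upper),
        rhs,
        solve=lambda _matvec, b: thomas_solve(lower, diag, upper, b),
        transpose_solve=lambda _vecmat, b: thomas_solve(lower_t, diag, upper_t, b),
    )
```

Reverse-mode differentiation with respect to rhs then runs the transposed Thomas solve rather than backpropagating through the two scans.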

Bearing that in mind and taking a look at these graphs, it seems to me that the desired heuristics are probably something like:

  • If on CPU then use our current implementation and unroll=1.
  • If on GPU and batch size is small, then use the jax.lax.linalg.tridiagonal_solve implementation.
  • If on GPU and batch size is large, then use our current implementation and unroll=32.
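The three heuristics can be written down as a small decision function. This is a hypothetical sketch: the method labels, the function name and the batch-size threshold are placeholders, and in practice the batch size is only visible from inside a vmap rule, so this logic would have to live wherever that rule is defined.

```python
from typing import Optional, Tuple

# Hypothetical threshold: the smallest batch size treated as "large" on GPU.
LARGE_BATCH_THRESHOLD = 2


def choose_tridiagonal_method(backend: str, batch_size: int) -> Tuple[str, Optional[int]]:
    """Return (method, unroll) following the CPU/GPU heuristics listed above.

    backend is a jax.default_backend()-style string such as "cpu" or "gpu";
    batch_size is 1 for an unbatched solve.
    """
    if backend == "cpu":
        return "thomas", 1
    # Any non-CPU backend is treated like the GPU case here.
    if batch_size < LARGE_BATCH_THRESHOLD:
        return "jax_builtin", None  # i.e. jax.lax.linalg.tridiagonal_solve
    return "thomas", 32
```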

I think that if we wanted to express this, we'd actually need to use a custom primitive, as I think dispatching on either batch size or device requires the use of custom primitives.

Alternatively, as a simpler approach, exposing unroll as a parameter whilst defaulting it to 1 would also be reasonable.

WDYT? Happy to take a PR either way, thank you for the offer!

patrick-kidger, Apr 14 '25 18:04

> (and the corresponding entries in the graphs) which indicates that our unrolling actually hinders performance. Which version of JAX/jaxlib is this on?

JAX/jaxlib v0.5.3. I don't deeply understand it, but a bit of googling seems to suggest that unrolling can quite often be detrimental on CPU due to cache misses and the like. On GPU, by contrast, loop overheads (at least the way jax.lax.scan creates them) can be significant, so unrolling can provide meaningful performance gains.

If it's not too hacky, the more complex solution sounds great if we define "large batch size" as 1 or 2 (i.e. any time we try to batch), because:

  1. Fully unrolled, the lineax implementation is as efficient as jax (except for compile time) even for a batch size of 1.
  2. This avoids the need to do extensive tests on what the optimal batch size is for an unroll of 32 as a function of array size.

I will look into how complex the implementation would be and, if it is overly onerous, go with the simpler approach for now (perhaps with device-dependent defaults, if possible, so as not to degrade the performance of existing GPU code). The only real downside to the simpler approach is the higher compile time if fully unrolling on GPU.

I will also look into the JVP and transpose rules on both the lineax and jax side (jax definitely has custom rules, but I haven't delved into what they are). Without looking at the code and just scribbling with pen and paper, my assumption is that autodiff in lineax is computed with a linear solve of RHS' - A' x (where ' denotes the derivative). Perhaps, at least for some uses of tridiagonal solve, it is more efficient to trace through the solver, or jax has some other trick? I will report back here with my findings on autodiff, and will not make any changes affecting autodiff in the initial PR until we've had a chance to discuss.
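The pen-and-paper rule follows from differentiating A x = b, which gives A' x + A x' = b', so x' = A^{-1}(b' - A' x): one extra solve with the same matrix. A quick numerical check of that identity against JAX's own JVP of jnp.linalg.solve (a dense stand-in, not lineax's code; the matrix sizes and random seed are arbitrary):

```python
import jax
import jax.numpy as jnp

n = 5
key_a, key_b, key_da, key_db = jax.random.split(jax.random.PRNGKey(0), 4)
a = jax.random.normal(key_a, (n, n)) + n * jnp.eye(n)  # shifted to keep it well conditioned
b = jax.random.normal(key_b, (n,))
a_dot = jax.random.normal(key_da, (n, n))  # tangent of the matrix, A'
b_dot = jax.random.normal(key_db, (n,))    # tangent of the right-hand side, b'

# JAX's built-in JVP of the dense solve...
x, x_dot_autodiff = jax.jvp(jnp.linalg.solve, (a, b), (a_dot, b_dot))

# ...versus the implicit rule: A' x + A x' = b'  =>  x' = A^{-1} (b' - A' x).
x_dot_implicit = jnp.linalg.solve(a, b_dot - a_dot @ x)
```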

Thanks for the quick response! 🙂

jpbrodrick89, Apr 14 '25 20:04

'define large batch size as 2' sounds reasonable to me!

I think we should probably avoid fully unrolling, for specifically the reasons of compile time. I try to do a careful job of that in the Equinox ecosystem, and sometimes that means making a runtime trade-off like this.

As for the autodiff, the JVP is implemented here:

https://github.com/patrick-kidger/lineax/blob/2a18660733a0f202bc7f370394503b3cadfd3d0e/lineax/_solve.py#L141

which in particular picks up a bit of extra complexity to handle linear least squares problems (e.g. SVD). We handle these through the same API as they're a strict generalisation of linear solves.

And the transposition is handled here:

https://github.com/patrick-kidger/lineax/blob/2a18660733a0f202bc7f370394503b3cadfd3d0e/lineax/_solve.py#L275

Both of these rules are completely agnostic to the solver that is being used.
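The solver-agnostic nature of the transpose is easy to see for a square solve: for x = A^{-1} b and a cotangent x̄, the cotangent of b is λ = A^{-T} x̄ and the cotangent of A is -λ xᵀ, so only a solve with Aᵀ is needed, whichever solver produced x. A small check against jax.vjp, using the dense jnp.linalg.solve as a stand-in (not lineax's implementation; sizes and seed are arbitrary):

```python
import jax
import jax.numpy as jnp

n = 5
key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(1), 3)
a = jax.random.normal(key_a, (n, n)) + n * jnp.eye(n)
b = jax.random.normal(key_b, (n,))
x_bar = jax.random.normal(key_c, (n,))  # cotangent of the solution

x, vjp_fn = jax.vjp(jnp.linalg.solve, a, b)
a_bar, b_bar = vjp_fn(x_bar)

# Transposed rule: one extra solve with A^T, reused for both cotangents.
lam = jnp.linalg.solve(a.T, x_bar)
```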

patrick-kidger, Apr 14 '25 22:04

So, it turns out that the relative slowdown for the VJP is entirely eliminated if I explicitly set throw=False on my lineax fork (the JVP and transpose rules both override the user's setting of throw, as "there is nowhere to pipe the error to"). I'm not sure how much of this slowdown is due to having to recompile the solve with a new static value of throw and how much is due to error_if itself (which probably blocks some compiler optimisations?), but previous tests on the forward pass suggested that throw had a significant impact on performance (I can't remember the exact numbers right now, but can easily test again if you're interested). How harmful or dangerous would it be to pass the primal value of throw to the solves in the JVP and transpose rules? Or should we instead focus on reducing the performance impact of error_if in equinox? (Note: the EQX_ON_ERROR environment variable is not set on my machines, and I'm not sure what the default is, presumably raise; perhaps changing it might reduce the performance impact too.)

[Benchmark graph (image not reproduced)]

jpbrodrick89, Apr 15 '25 10:04

After looking at the jax implementation of tridiagonal_solve, it seems that to mirror its device-dependent lowering we would indeed have to create a new custom primitive. Would there be any complications with this, considering that linear_solve_p is already itself a primitive?

If we decide against creating a new custom primitive and go with the simpler solution, would it still be possible to use jax.default_backend() to conditionally set the default unroll if none is provided? I'm not sure how jit-friendly this is. (Alternatively, we could just add a static argument backend or use_jax.)

jpbrodrick89, Apr 15 '25 10:04

I don't think there should be any complications. (It's totally fine to have 'primitives inside primitives' like this.) If anything it should be easier, as there is no need to create jvp or transposition rules -- basically just impl+abstract+vmap+lowering rules.

As for error_if on the backward pass, it's definitely expected that this introduces a slowdown. My usual recommendation for this is to disable it with EQX_ON_ERROR=nan once a user feels their code is sufficiently robust to get away without error-checking. Sadly I don't know of a more performant way to handle this problem by default.

We could certainly check the default backend. That's probably a reasonable heuristic for which backend the code will then end up actually being used with.
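On the jit-friendliness question: jax.default_backend() is ordinary Python, so when called inside a jitted function it runs once at trace time and its answer is baked into the compiled program. A minimal sketch of resolving unroll=None this way; resolve_unroll and running_sum are hypothetical names, the 1/32 defaults are the ones proposed earlier in the thread, and running_sum is only a stand-in for the scan inside a tridiagonal solve. As a heuristic, it will not notice a computation that is explicitly placed on a non-default device.

```python
import functools

import jax
import jax.numpy as jnp


def resolve_unroll(unroll=None):
    """Hypothetical helper: an explicit unroll wins, else pick one from the default backend."""
    if unroll is not None:
        return unroll
    # jax.default_backend() returns e.g. "cpu", "gpu" or "tpu".
    return 1 if jax.default_backend() == "cpu" else 32


@functools.partial(jax.jit, static_argnames="unroll")
def running_sum(x, unroll=None):
    # Plain Python runs at trace time, so the backend check happens once per
    # compilation rather than on every call.
    unroll = resolve_unroll(unroll)

    def step(total, xi):
        total = total + xi
        return total, total

    _, ys = jax.lax.scan(step, jnp.zeros((), x.dtype), x, unroll=unroll)
    return ys
```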

patrick-kidger, Apr 15 '25 16:04

GPU batching in JAX has now been addressed, meaning that the jax implementation beats the lineax Thomas algorithm:

[Benchmark graph (image not reproduced)]

The CPU case is now the only one in which lineax slightly outperforms jax. My belief is that this is because lineax uses the faster but potentially unstable Thomas algorithm, whereas LAPACK/cuSPARSE use partial pivoting.
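To make the stability point concrete: without pivoting, the Thomas algorithm divides by each diagonal pivot in turn, so a zero leading entry breaks it even when the matrix itself is perfectly well conditioned. The 2x2 example below is illustrative; thomas_solve_loop is a hypothetical textbook implementation, and the dense jnp.linalg.solve (LU with partial pivoting) stands in for a pivoted tridiagonal solver such as LAPACK's gtsv.

```python
import jax.numpy as jnp


def thomas_solve_loop(lower, diag, upper, rhs):
    """Textbook Thomas algorithm (no pivoting) as a plain Python loop."""
    n = diag.shape[0]
    c_star = [upper[0] / diag[0]]
    d_star = [rhs[0] / diag[0]]
    for i in range(1, n):
        denom = diag[i] - lower[i] * c_star[i - 1]
        c_star.append(upper[i] / denom)
        d_star.append((rhs[i] - lower[i] * d_star[i - 1]) / denom)
    x = [d_star[-1]]
    for i in range(n - 2, -1, -1):
        x.insert(0, d_star[i] - c_star[i] * x[0])
    return jnp.stack(x)


# A = [[0, 1], [1, 0]] is well conditioned, but its first pivot is zero.
lower = jnp.array([0.0, 1.0])
diag = jnp.array([0.0, 0.0])
upper = jnp.array([1.0, 0.0])
rhs = jnp.array([1.0, 2.0])

x_thomas = thomas_solve_loop(lower, diag, upper, rhs)  # division by zero: non-finite result
dense = jnp.array([[0.0, 1.0], [1.0, 0.0]])
x_pivoted = jnp.linalg.solve(dense, rhs)  # row swap handles it: [2., 1.]
```

For strictly diagonally dominant systems, such as many implicit diffusion steps, no pivot can vanish in this way, which is why the Thomas algorithm is usually considered safe there.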

@patrick-kidger I understand that a stable tridiagonal solve has been on the laundry list #3 for a long time, so I think the default tridiagonal solve should always use the jax.lax implementation. Do you agree? If so, the question is whether we still want to make the Thomas algorithm available in some way, for slightly faster CPU solves of symmetric/diagonally dominant tridiagonal systems. Would the preferred API for this be to expose it simply as a solver option to linear_solve, defaulting to the stable approach where possible? More complex approaches could involve a SymmetricTridiagonalMatrix operator, or defaulting to Thomas when the symmetric tag is provided, so it would be good to reach consensus on the interface before diving in. Also, it's not clear how a thomas_solver should behave on GPU: should it just run slowly, or default to the partial-pivoting approach until we can expose cuThomas?

jpbrodrick89, May 14 '25 08:05

I agree!

This sounds like the performance gain here is marginal, and keeping it would increase maintenance cost and open up a potential footgun for users. I'm inclined not to keep it around!

patrick-kidger, May 14 '25 18:05