
Batch mode for `A X = B` with B an `n x m` matrix

Open · vboussange opened this issue on Nov 19, 2024 · 4 comments

Hey there! Some native JAX solvers, such as jnp.linalg.solve and jax.scipy.sparse.linalg.gmres, nicely support batch mode, where the right-hand side of the system $A X = B$ is an $n \times m$ matrix. What is the best approach to efficiently reproduce this behaviour with lineax?

I made a benchmark using vmap and lineax, but this approach is about 4x slower:

import jax.numpy as jnp
import jax.random as jr
from jax import vmap, jit
import lineax as lx
import timeit


N = 20
key = jr.PRNGKey(0)
A = jr.uniform(key, (N, N))
B = jnp.eye(N, N)

@jit
def linalg_solve():
    x = jnp.linalg.solve(A, B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error

def lineax_solve(solver):
    operator = lx.MatrixLinearOperator(A)
    # Factorise A once up front, then reuse the state for every column of B.
    state = solver.init(operator, options={})
    def solve_single(b):
        x = lx.linear_solve(operator, b, solver=solver, state=state).value
        return x
    # Batch over the columns of B.
    x = vmap(solve_single, in_axes=1, out_axes=1)(B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error

def benchmark(method, func):
    time_taken = timeit.timeit(func, number=10) / 10
    _, error = func()
    print(f"{method} solve error: {error:2e}")
    print(f"{method} average time: {time_taken * 1e3:.2f} ms\n")

benchmark("linalg.solve", linalg_solve)
# linalg.solve solve error: 6.581411e-06
# linalg.solve average time: 0.03 ms

myfun = jit(lambda: lineax_solve(lx.LU()))
benchmark("lineax", myfun)
# lineax solve error: 6.581411e-06
# lineax average time: 0.13 ms

vboussange commented on Nov 19, 2024

So (a) I think you've made a few mistakes in the benchmarking, and (b) most Lineax/Optimistix/Diffrax routines finish with an optional check that throws a runtime error if things have gone wrong, and this adds a measurable amount of overhead on microbenchmarks such as this one. It can be disabled with throw=False.

So adjusting things a little, I get exactly comparable results between the two approaches.

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
import timeit

@jax.jit
def linalg_solve(A, B):
    x = jnp.linalg.solve(A, B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error

@jax.jit
def lineax_solve(A, B):
    operator = lx.MatrixLinearOperator(A)
    def solve_single(b):
        # throw=False disables the runtime success check (NaNs are returned on failure).
        x = lx.linear_solve(operator, b, throw=False).value
        return x
    x = jax.vmap(solve_single, in_axes=1, out_axes=1)(B)
    error = jnp.linalg.norm(B - (A @ x))
    return x, error

def benchmark(method, func):
    times = timeit.repeat(func, number=1, repeat=10)
    _, error = func()
    print(f"{method} solve error: {error:2e}")
    print(f"{method} min time: {min(times)}\n")

N = 20
key = jr.PRNGKey(0)
A = jr.uniform(key, (N, N))
B = jnp.eye(N, N)

# Compile both functions before timing, so compilation isn't measured.
linalg_solve(A, B)
lineax_solve(A, B)

benchmark("linalg.solve", lambda: jax.block_until_ready(linalg_solve(A, B)))
benchmark("lineax", lambda: jax.block_until_ready(lineax_solve(A, B)))

# linalg.solve solve error: 7.080040e-06
# linalg.solve min time: 4.237500252202153e-05
#
# lineax solve error: 7.080040e-06
# lineax min time: 3.9375037886202335e-05

Notable changes here:

  • Using throw=False to disable Lineax's checking for success (and just silently returning NaNs if things go wrong).
  • Using jax.block_until_ready, so that JAX's asynchronous dispatch doesn't distort the timings.
  • Compiling prior to evaluating, so that we don't measure differences in compilation speed.
  • Using min with repeat=10, rather than the mean, over the evaluation times. Since benchmarking noise is one-sided, the minimum is usually the correct aggregation method for microbenchmarks.
  • Actually passing in inputs to the JIT'd region. What you've written here could in principle be entirely constant-folded by the compiler (see the sketch below).
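
To illustrate that last point, here is a minimal sketch (not from the thread; jnp.linalg.solve stands in for either solve function). Closing over the arrays bakes them into the trace as constants, while passing them as arguments keeps them as runtime inputs:

import jax
import jax.numpy as jnp

A = jnp.eye(3)
B = jnp.ones((3, 3))

# A and B are closed over: inside the trace they are compile-time constants,
# so the compiler is free to precompute (constant-fold) some or all of the solve.
baked = jax.jit(lambda: jnp.linalg.solve(A, B))

# A and B are arguments: they are runtime inputs, so the compiled function
# must actually perform the solve on every call. This is what you want to time.
honest = jax.jit(jnp.linalg.solve)
x = honest(A, B)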

FWIW I've also trimmed out the use of state and the explicit lineax.LU() solver, as the former is done already inside the solve and the latter is the default.
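
Concretely, a minimal sketch of that equivalence (assuming, per the above, that LU is the default for a square MatrixLinearOperator and that the solver state is created inside the solve):

import jax.numpy as jnp
import jax.random as jr
import lineax as lx

operator = lx.MatrixLinearOperator(jr.uniform(jr.PRNGKey(0), (20, 20)))
b = jnp.ones(20)

# Equivalent here: LU is the default solver for a square matrix operator,
# and init/state are handled internally when not passed explicitly.
sol_default = lx.linear_solve(operator, b)
sol_explicit = lx.linear_solve(operator, b, solver=lx.LU())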

patrick-kidger commented on Nov 20, 2024

Excellent, thanks for the details!

> FWIW I've also trimmed out the use of state and the explicit lineax.LU() solver, as the former is done already inside the solve and the latter is the default.

I am surprised that the vmap is not triggering multiple internal init calls?

vboussange commented on Nov 21, 2024

> I am surprised that the vmap is not triggering multiple internal init calls?

init is called only on the non-vmap'd input A, so it won't be vmap'd: the factorisation happens once, and the vmap'd solve simply closes over the resulting state.
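
As an illustration, here is a minimal sketch (reusing the setup from the original post) in which the factorisation is visibly hoisted out of the vmap:

import jax
import jax.numpy as jnp
import jax.random as jr
import lineax as lx

A = jr.uniform(jr.PRNGKey(0), (20, 20))
B = jnp.eye(20)

solver = lx.LU()
operator = lx.MatrixLinearOperator(A)
state = solver.init(operator, options={})  # factorises A exactly once

def solve_single(b):
    # operator, solver, and state are closed over, not vmap'd; only b is batched.
    return lx.linear_solve(operator, b, solver=solver, state=state).value

x = jax.vmap(solve_single, in_axes=1, out_axes=1)(B)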

patrick-kidger commented on Nov 21, 2024

Of course, makes total sense. Thanks for the details!

vboussange commented on Nov 25, 2024