Streamline tridiagonalisation of JacobianLinearOperator/FunctionLinearOperator using coloring method
BREAKING CHANGE: Tridiagonals are no longer "extracted" from operators. Instead, the code relies on the tridiagonal tag being accurate. If this promise is broken, solutions may differ from user expectations.
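The sketch below makes the breaking change concrete. It is plain JAX, not lineax's implementation, and shows what happens when a map treated as tridiagonal is actually pentadiagonal. It assumes the every-third-column seeding described under "Changes made", and the matrix `A` is purely illustrative.

With a genuinely tridiagonal matrix, one JVP with the color-0 seed would return `A[1, 0]` in position 1. Here the extra band leaks `A[1, 3]` into the same slot, and no error is raised.

```python
import jax
import jax.numpy as jnp

n = 6
# Illustrative pentadiagonal matrix: the +/-2 bands break the tridiagonal promise.
A = (
    jnp.diag(jnp.full(n, 4.0))
    + jnp.diag(jnp.full(n - 1, 1.0), 1)
    + jnp.diag(jnp.full(n - 1, 1.0), -1)
    + jnp.diag(jnp.full(n - 2, 0.5), 2)
    + jnp.diag(jnp.full(n - 2, 0.5), -2)
)

def f(x):
    return A @ x

x0 = jnp.zeros(n)
# Color-0 seed: ones in columns 0 and 3, zeros elsewhere.
seed = (jnp.arange(n) % 3 == 0).astype(x0.dtype)
_, compressed = jax.jvp(f, (x0,), (seed,))

# For a truly tridiagonal A, compressed[1] would be exactly A[1, 0].
# Here the +2 band also contributes A[1, 3], so the "extracted" entry is silently wrong.
print(compressed[1], A[1, 0])
```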
Preface
Very similar in spirit to #164, and partially addresses the problem discussed in #149. This turned out to be a lot simpler than I anticipated, and I thought it would be helpful to compare and contrast the two implementations and their performance impacts. In this case, the performance improvement is significant and consistent for array sizes > 75, even when the JVP batching rule is efficient. However, this implementation does not support primitives for which no batching rule is available at all. As I am just vmapping over three vectors, I could perhaps write an explicit for loop instead, but it is probably best to prioritise efficiency over flexibility here (or maybe we can do both, if we can query whether the rule exists within a jitted function?).
Some of the variable names and the overall spirit of the method are inspired directly by sparsejac, so we should probably add attribution (under its MIT license) somewhere.
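The trade-off between vmapping over the three seed vectors and an explicit loop can be sketched in plain JAX. Both forms below compute the same three JVPs.

`jax.vmap` needs a batching rule for every primitive in the JVP. `jax.lax.map` lowers to a scan, so its body is traced unbatched and needs no such rule, at the cost of running the three products sequentially. The seeding and `diffus_step` (the benchmark's step without the unused `args`) are illustrative assumptions. A plain Python `for` loop would also work, unrolled into three copies under `jit`.

```python
import jax
import jax.numpy as jnp

def diffus_step(y):
    # Same stencil as diffus_step in the benchmark code, minus the unused args.
    return y + 0.1 * jnp.diff(jnp.diff(y), prepend=0.0, append=0.0)

y0 = jnp.linspace(0.0, 1.0, 8)
idx = jnp.arange(y0.size)
# Three 0/1 seed vectors, one per color (column index mod 3).
seeds = jnp.stack([(idx % 3 == k).astype(y0.dtype) for k in range(3)])

def jvp_with(seed):
    # Jacobian-vector product of diffus_step at y0 in the direction `seed`.
    return jax.jvp(diffus_step, (y0,), (seed,))[1]

# Batched: requires a vmap rule for every primitive in the JVP of diffus_step.
batched = jax.vmap(jvp_with)(seeds)
# Sequential: lax.map lowers to a scan, so the body is traced unbatched.
looped = jax.lax.map(jvp_with, seeds)
```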
Changes made
The single-dispatch tridiagonal function has been rewritten for JacobianLinearOperator/FunctionLinearOperator. It builds on an observation from coloring methods: all elements of a tridiagonal matrix can be disentangled from just three carefully constructed vector pre/post-multiplications (VJPs or JVPs). I have gone with the simple case where each vector has ones in every third position (offset by 0, 1 or 2) and zeros elsewhere. This means the full Jacobian never needs to be calculated.
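The coloring idea can be sketched in plain JAX. The helper name `tridiagonal_via_coloring` is hypothetical and not lineax's API. Like the PR, it simply trusts that the Jacobian is tridiagonal.

Column `j` gets color `j % 3`. One JVP per color then sums the columns of that color. Row `i` only has nonzeros in columns `i-1`, `i` and `i+1`, which carry distinct colors, so every sum contains at most one nonzero term and each entry can be read off directly.

```python
import jax
import jax.numpy as jnp

def tridiagonal_via_coloring(fn, x):
    """Hypothetical helper: (diag, lower, upper) of the Jacobian of fn at x from three JVPs.

    Assumes the Jacobian really is tridiagonal; nothing checks this.
    """
    idx = jnp.arange(x.size)
    # Seed k has ones in the columns j with j % 3 == k.
    seeds = jnp.stack([(idx % 3 == k).astype(x.dtype) for k in range(3)])
    # compressed[k, i] = sum of J[i, j] over the columns j with j % 3 == k.
    compressed = jax.vmap(lambda s: jax.jvp(fn, (x,), (s,))[1])(seeds)
    diag = compressed[idx % 3, idx]                   # J[i, i]
    upper = compressed[(idx[:-1] + 1) % 3, idx[:-1]]  # J[i, i+1]
    lower = compressed[(idx[1:] - 1) % 3, idx[1:]]    # J[i+1, i]
    return diag, lower, upper

def diffus_step(y):
    # The benchmark's diffusion step, minus the unused args.
    return y + 0.1 * jnp.diff(jnp.diff(y), prepend=0.0, append=0.0)

y = jnp.linspace(0.0, 1.0, 10)
diag, lower, upper = tridiagonal_via_coloring(diffus_step, y)
```

For this diffusion step, the recovered main diagonal is 0.8 in the interior and 0.9 at both ends, and both off-diagonals are 0.1. These are the same values `using_ones_like` hard-codes in the benchmark.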
Performance Impact
Applying this to the problem discussed in #149, where we attempt to use optimistix to provide an implicit solve of a diffusion equation, and comparing it with straightforward lineax.linear_solve calls using various operators (with EQX_ON_ERROR=nan):
Code
```python
import jax.numpy as jnp
import jax
from jax.experimental import sparse  # needed for sparse.BCOO below
import lineax
import optimistix
import sparsejac

jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_disable_jit", True)
key = jax.random.PRNGKey(0)

def laplacian_numerator(y):
    # Second difference with zero-flux ends; output has the same length as y.
    return jnp.diff(jnp.diff(y), prepend=0., append=0.)

def diffus_step(y, args):
    # One explicit diffusion step; its Jacobian is tridiagonal.
    return y + 0.1 * laplacian_numerator(y)

def loss(y, args):
    return diffus_step(y, args) - y

tridiag_solver = lineax.Tridiagonal()
newton_solver = optimistix.Newton(rtol=1e-8, atol=1e-8, linear_solver=lineax.Tridiagonal())

@jax.jit
def with_root_find(rhs):
    # Newton builds a FunctionLinearOperator internally, tagged tridiagonal here.
    return optimistix.root_find(loss, newton_solver, rhs, tags=frozenset({lineax.tridiagonal_tag}), max_steps=1, throw=False)

@jax.jit
def jac_func_to_sol(rhs):
    # Dense Jacobian via jacfwd, then a tridiagonal solve.
    jac = jax.jacfwd(diffus_step)(rhs, None)
    op = lineax.MatrixLinearOperator(jac, tags=frozenset({lineax.tridiagonal_tag}))
    return lineax.linear_solve(op, rhs, tridiag_solver, throw=False)

def tridiag_sparsity(x):
    dummy_tridiag = jnp.diag(x) + jnp.diag(x[:-1], 1) + jnp.diag(x[1:], -1)
    return sparse.BCOO.fromdense(dummy_tridiag)

@jax.jit
def sparsejac_func_to_sol(rhs):
    # Sparse Jacobian via sparsejac's coloring, densified for the solve.
    with jax.ensure_compile_time_eval():
        sparsejac_func = sparsejac.jacfwd(diffus_step, sparsity=tridiag_sparsity(jnp.ones_like(rhs)))
    jac = sparsejac_func(rhs, None).todense()
    op = lineax.MatrixLinearOperator(jac, tags=frozenset({lineax.tridiagonal_tag}))
    return lineax.linear_solve(op, rhs, tridiag_solver, throw=False)

@jax.jit
def using_ones_like(rhs):
    # Hand-built diagonals of the diffusion step's Jacobian (the reference case).
    du = jnp.ones_like(rhs[:-1])*0.1
    d = jnp.ones_like(rhs)*0.8
    d = d.at[0].set(0.9)
    d = d.at[-1].set(0.9)
    dl = jnp.ones_like(rhs[1:])*0.1
    op = lineax.TridiagonalLinearOperator(d, dl, du)
    return lineax.linear_solve(op, rhs, tridiag_solver, throw=False)

@jax.jit
def jlo_to_sol(rhs):
    # JacobianLinearOperator: the code path changed by this PR.
    op = lineax.JacobianLinearOperator(diffus_step, rhs, tags=frozenset({lineax.tridiagonal_tag}), jac="fwd")
    return lineax.linear_solve(op, rhs, tridiag_solver, throw=False)
```
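Before comparing timings, it is worth checking that all five variants really solve the same system. In particular, the hand-built diagonals in `using_ones_like` should match the Jacobian of `diffus_step`, and the tridiagonal tag should be honest.

Below is a minimal check using plain JAX only (no lineax). It assumes the same step function and a small, arbitrary `n`.

```python
import jax
import jax.numpy as jnp

def diffus_step(y, args):
    return y + 0.1 * jnp.diff(jnp.diff(y), prepend=0.0, append=0.0)

n = 6
jac = jax.jacfwd(diffus_step)(jnp.zeros(n), None)

# Same diagonals as using_ones_like: 0.8 inside, 0.9 at the ends, 0.1 off-diagonal.
d = jnp.full(n, 0.8).at[0].set(0.9).at[-1].set(0.9)
off = jnp.full(n - 1, 0.1)
expected = jnp.diag(d) + jnp.diag(off, 1) + jnp.diag(off, -1)

max_err = jnp.max(jnp.abs(jac - expected))
# Largest entry outside the three central bands (should be exactly zero).
outside_bands = jnp.max(jnp.abs(jnp.triu(jac, 2)) + jnp.abs(jnp.tril(jac, -2)))
```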
Comparison for various sizes of rhs (using EQX_ON_ERROR=nan and jax_enable_x64, with all other environment variables at their defaults):
This PR improves the Newton solve (which uses FunctionLinearOperator under the hood) and linear_solve with JacobianLinearOperator for array sizes beyond 75. JacobianLinearOperator is now as efficient as TridiagonalLinearOperator (at least for a linear function). A speedup of more than 15x is observed for array sizes of 1E3, and of more than 400x for array sizes of 1E4. The performance hit for smaller array sizes is probably again due to threading, which I haven't controlled for here (happy to look into it). As discussed in #149, the factor-of-2 performance hit of the Newton solve is likely due to the Cauchy termination test, which requires two steps.
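The sketch below shows one way to time the two extraction strategies while excluding compilation. It is not the harness behind the numbers above, which the PR does not show.

It compares materialising the full Jacobian with `jax.jacfwd` (one JVP per column) against three colored JVPs. The function names and `repeats` are illustrative assumptions. No particular speedup is asserted, since timings depend on hardware and threading.

```python
import time

import jax
import jax.numpy as jnp

def diffus_step(y):
    return y + 0.1 * jnp.diff(jnp.diff(y), prepend=0.0, append=0.0)

@jax.jit
def dense_diagonal(y):
    # Builds the full n x n Jacobian, then keeps only its main diagonal.
    return jnp.diagonal(jax.jacfwd(diffus_step)(y))

@jax.jit
def colored_diagonal(y):
    # Three JVPs with every-third-column seeds; valid only for tridiagonal Jacobians.
    idx = jnp.arange(y.size)
    seeds = jnp.stack([(idx % 3 == k).astype(y.dtype) for k in range(3)])
    compressed = jax.vmap(lambda s: jax.jvp(diffus_step, (y,), (s,))[1])(seeds)
    return compressed[idx % 3, idx]

def best_time(fn, y, repeats=5):
    fn(y).block_until_ready()  # compile and warm up outside the timed region
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(y).block_until_ready()  # wait for the async dispatch to finish
        times.append(time.perf_counter() - start)
    return min(times)

y = jnp.linspace(0.0, 1.0, 2000)
print("dense:", best_time(dense_diagonal, y), "colored:", best_time(colored_diagonal, y))
```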
I have not tested custom primitives yet; let me know whether I should add this.
Testing done
- [x] CI passes
- [x] Added `test_tridiagonal` test in `test_operators.py`
- [x] Added `make_jacfwd_operator` and `make_jacrev_operator` in `helpers.py` to ensure this continues to work when `jac=True`
I addressed the inefficiencies compared to sparsejac, and this now runs as fast as TridiagonalLinearOperator! I also added a test for jac="bwd" in JacobianLinearOperator. CI might now fail due to the new make_operators, and I might need to add some more skips (jac="bwd" doesn't support complex128, and I'm not sure whether making holomorphic=True the default is a good idea).
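For the jac="bwd" path, the same three seeds work on the other side of the Jacobian. A VJP with the color-`k` seed sums the rows of that color. Column `j` of a tridiagonal matrix is nonzero only in rows `j-1`, `j` and `j+1`, which again carry distinct colors.

Below is a minimal plain-JAX sketch. The helper `tridiagonal_via_vjp_coloring` is hypothetical, not lineax's code. The matrix uses different lower and upper diagonals so that a mix-up between them would show.

```python
import jax
import jax.numpy as jnp

def tridiagonal_via_vjp_coloring(fn, x):
    """Hypothetical helper: (diag, lower, upper) of the Jacobian of fn at x from three VJPs."""
    idx = jnp.arange(x.size)
    # Seed k has ones in the rows i with i % 3 == k.
    seeds = jnp.stack([(idx % 3 == k).astype(x.dtype) for k in range(3)])
    _, vjp_fn = jax.vjp(fn, x)
    # compressed[k, j] = sum of J[i, j] over the rows i with i % 3 == k.
    compressed = jax.vmap(lambda s: vjp_fn(s)[0])(seeds)
    diag = compressed[idx % 3, idx]              # J[i, i]
    upper = compressed[idx[:-1] % 3, idx[1:]]    # J[i, i+1]: row i, column i+1
    lower = compressed[idx[1:] % 3, idx[:-1]]    # J[i+1, i]: row i+1, column i
    return diag, lower, upper

n = 7
d = jnp.arange(1.0, n + 1.0)
dl = jnp.full(n - 1, -2.0)  # J[i+1, i]
du = jnp.full(n - 1, 3.0)   # J[i, i+1]
A = jnp.diag(d) + jnp.diag(dl, -1) + jnp.diag(du, 1)
diag, lower, upper = tridiagonal_via_vjp_coloring(lambda x: A @ x, jnp.zeros(n))
```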
We might want to consider adding attribution to sparsejac somewhere under their MIT license.
I will update the original PR comment.
CI passes on my machine with Python 3.13 and on GH with 3.10. The only failure on GH with 3.12 is in GMRES, with this error:
equinox.EquinoxRuntimeError: A form of iterative breakdown has occured in a linear solve. Try using a different solver for this problem or increase `restart` if using GMRES.
Nice, I like this in the same way as #164. I think the same comments I have there apply here; otherwise this essentially LGTM.
I'm happy to assume the presence of a vmap rule where necessary. I think that, if need be, the custom primitive can be given a vmap rule that just does a for loop at that level, rather than attempting to hoist it out here.
Hi, just wanted to ask what the status is for this PR? We got a request for a new release, and I want to compile a To-Do list for what we want to have in before we do a release.
See #164 for current blockers there that apply here too, thanks!
Thank you :) I'll familiarise myself with the state of the discussion.