
Is there a way to auto-recognise operator structure?

Open jpbrodrick89 opened this issue 8 months ago • 10 comments

I am looking to extend the diffrax example on the 1D heat equation to use a Tridiagonal solve, but unfortunately lineax cannot tell that a generic 2D array has tridiagonal structure. If this is not possible, could we make a change to optimistix to create the desired LinearOperator from the Jacobian?

MWE for context:

import jax.numpy as jnp
import jax
import lineax
import optimistix
jax.config.update("jax_enable_x64", True)

def laplacian_numerator(y):
    return jnp.diff(jnp.diff(y), prepend=0., append=0.)

def diffus_step(y, args):
    return y + 0.1 * laplacian_numerator(y)

solver = optimistix.Newton(rtol=1e-8, atol=1e-8, linear_solver=lineax.Tridiagonal())

key = jax.random.PRNGKey(0)
rhs = jax.random.normal(key, (100,))

@jax.jit
def loss(y, args):
    return diffus_step(y, args) - rhs

sol = optimistix.root_find(lambda y, args: diffus_step(y, args) - rhs, solver, rhs, max_steps=1, throw=False)

Returns the following traceback:

...
--> 220 return iterative_solve(
    221     fn,
    222     solver,
    223     y0,
    224     args,
    225     options,
    226     max_steps=max_steps,
    227     adjoint=adjoint,
    228     throw=throw,
    229     tags=tags,
    230     f_struct=f_struct,
    231     aux_struct=aux_struct,
    232     rewrite_fn=_rewrite_fn,
...
     53         "matrices"
     54     )
     55 return tridiagonal(operator), pack_structures(operator)

ValueError: `Tridiagonal` may only be used for linear solves with tridiagonal matrices

jpbrodrick89 avatar May 06 '25 16:05 jpbrodrick89

Yup! So at least from Optimistix, you can pass optx.root_find(..., tags=frozenset({lineax.tridiagonal_tag})) to specify the structure of the Jacobian of the target function. Note that this is a promise that is not checked by either Lineax or Optimistix.
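For the MWE above, that looks something like this (a minimal sketch; only the tags argument is new):

sol = optimistix.root_find(
    lambda y, args: diffus_step(y, args) - rhs,
    solver,
    rhs,
    tags=frozenset({lineax.tridiagonal_tag}),  # promise that the Jacobian is tridiagonal
    max_steps=1,
    throw=False,
)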

We don't try to determine the structure automatically at runtime, as that would imply compiling every linear solver and then dispatching appropriately... and compiling every linear solver would be a huge compile-time overhead. Instead, we ask for a compile-time flag if required.

(FWIW the above is not great UX and one that I would be happy to take improvements on, I just haven't found a better way to express things.)

Since you mention Diffrax: FWIW, our RK solvers currently don't offer a way to specify the tags in the same way. We could add this as an argument to them. Although for PDE purposes perhaps you're writing your own Crank-Nicolson (etc.) solver yourself, in which case hopefully you can specify the tags as above.
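For illustration, a rough sketch of a hand-rolled Newton update that passes the tag straight to Lineax (reusing diffus_step and rhs from your MWE; this is not the Diffrax API) would be:

def newton_update(y):
    residual = lambda z, args: diffus_step(z, args) - rhs
    op = lineax.JacobianLinearOperator(
        residual, y, tags=frozenset({lineax.tridiagonal_tag})
    )
    # Solve J @ dy = residual(y) with the tridiagonal solver, then take the step.
    dy = lineax.linear_solve(op, residual(y, None), solver=lineax.Tridiagonal()).value
    return y - dy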

patrick-kidger avatar May 06 '25 16:05 patrick-kidger

Thanks Patrick, this worked great; I'm perfectly happy with a "promise" rather than a check. We are currently doing manual operator splitting with implicit/explicit Euler steps, so using optimistix directly should be fine for now. If we want to use higher-order methods in the future, though, exposing the tags argument is probably a good idea. The tags could also probably be better documented in optimistix, perhaps by way of an example.

I can confirm that simply adding the tag decreases runtime by about 30% on CPU on my lineax branch with unroll=1 in the solver (apologies for the delay on the PR for this; I am looking to see if we can improve jax batched performance first, to greatly simplify the whole affair).

I did some performance comparisons; there seems to be a fixed overhead of 15–50 µs that is noticeable at small matrix sizes (<150) compared to starting from a jitted Jacobian function. (I previously saw lineax be slower at large matrix sizes, but this was because I was using jacrev instead of jacfwd.) Any way we can cut this down? However, both of the former are a lot slower than when given the full matrix or when building the tridiagonal operator directly with ones_like. The large multiplicative difference is almost certainly due to the dense Jacobian calculation time. Almost all of the difference is eliminated when I use sparsejac (except when the array size exceeds ~5000, which I'm thinking might be due to the memory cost of instantiating the matrix). Could we use sparsejac or a homebrew equivalent in optimistix (maybe reusing the lineax tag)?

[Image: runtime comparison across matrix sizes of the approaches listed below]

The functions I used were as follows (starting from code above):

import jax.experimental.sparse
import sparsejac

@jax.jit
def with_root_find(rhs):
    return optimistix.root_find(loss, solver, rhs, tags=frozenset({lineax.tridiagonal_tag}), max_steps=1, throw=False)

jac_func = jax.jacobian(diffus_step)

@jax.jit
def jac_func_to_sol(rhs):
    jac = jac_func(rhs, None)
    op = lineax.MatrixLinearOperator(jac, tags=frozenset({lineax.tridiagonal_tag}))
    return lineax.linear_solve(op, rhs, throw=False)

def tridiag_sparsity(x):
    dummy_tridiag = jnp.diag(x) + jnp.diag(x[:-1], 1) + jnp.diag(x[1:], -1)
    return jax.experimental.sparse.BCOO.fromdense(dummy_tridiag)

@jax.jit
def sparsejac_func_to_sol(rhs):
    with jax.ensure_compile_time_eval():
        sparsejac_func = sparsejac.jacfwd(diffus_step, sparsity=tridiag_sparsity(rhs))
    jac = sparsejac_func(rhs, None).todense()
    op = lineax.MatrixLinearOperator(jac, tags=frozenset({lineax.tridiagonal_tag}))
    return lineax.linear_solve(op, rhs, throw=False)

@jax.jit
def jac_mat_to_sol(mat, rhs):
    op = lineax.MatrixLinearOperator(mat, tags=frozenset({lineax.tridiagonal_tag}))
    return lineax.linear_solve(op, rhs, throw=False)

@jax.jit
def using_ones_like(rhs):
    du = jnp.ones_like(rhs[:-1])*0.1
    d = jnp.ones_like(rhs)*0.8
    d = d.at[0].set(0.9)
    d = d.at[-1].set(0.9)
    dl = jnp.ones_like(rhs[1:])*0.1
    op = lineax.TridiagonalLinearOperator(d, dl, du)
    return lineax.linear_solve(op, rhs, throw=False)
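A timing harness along the following lines (an illustrative sketch; not necessarily the exact setup used for the plot above) reproduces this kind of comparison:

import timeit

mat = jac_func(rhs, None)  # precomputed dense Jacobian for jac_mat_to_sol

cases = {
    "root_find": lambda: with_root_find(rhs),
    "dense Jacobian + solve": lambda: jac_func_to_sol(rhs),
    "sparse Jacobian + solve": lambda: sparsejac_func_to_sol(rhs),
    "precomputed matrix": lambda: jac_mat_to_sol(mat, rhs),
    "hand-built tridiagonal": lambda: using_ones_like(rhs),
}
for name, fn in cases.items():
    jax.block_until_ready(fn())  # warm up: compile before timing
    t = timeit.timeit(lambda: jax.block_until_ready(fn()), number=1000)
    print(f"{name}: {t / 1000 * 1e6:.1f} µs per call")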

jpbrodrick89 avatar May 07 '25 01:05 jpbrodrick89

I can confirm that simply adding the tag in decreases runtime by about 30% on CPU on my lineax branch with unroll=1 in the solver

Awesome!

(apologies for the delay on the PR for this I am looking to see if we can improve jax batched performance first to greatly simplify the whole affair).

No worries, that sounds reasonable to me.

Is there any way to reduce this almost 2x hit? If not, I'm guessing optimistix should really only be used when the operator is nonlinear and we need the iterations?

Ah, this is interesting. You're right, it would be conceptually very elegant if we could just always use Optx and have that still do the efficient thing on linear problems.

Off the top of my head, I'm not sure where the extra Optx overhead is coming from. Is it additive or multiplicative as the problem size changes? If we can identify which part of the stack is introducing the overhead, then perhaps this is an easy fix.

patrick-kidger avatar May 07 '25 09:05 patrick-kidger

I just updated the comment above, which should answer your question on the scaling of the overhead (it's ~constant). I have not been able to identify where it's coming from yet; it probably needs a trawl through the optimistix source code. Any point looking at the jaxpr or some representation of the compiled XLA?

Switching to a sparse jacobian calculation would really be the big win here for large matrix sizes.

jpbrodrick89 avatar May 07 '25 10:05 jpbrodrick89

I think the 2x runtime in Optimistix is because we're taking an extra step by default, since we evaluate the convergence of a Cauchy sequence on y_diff and f_diff.

Cauchy termination is optional for the Newton and Chord root finders, but even in its absence we only terminate after at least two steps. Code here, including a #TODO for the case we're discussing.
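Roughly, the Cauchy check compares successive iterates and residuals against the tolerances, so it can only fire once two iterates exist. Something like this (a paraphrase of the idea, not the exact Optimistix code):

def cauchy_terminate(y_prev, y, f_prev, f, rtol, atol):
    # Stop once both the iterates and the residuals have stopped moving.
    y_converged = jnp.all(jnp.abs(y - y_prev) < atol + rtol * jnp.abs(y_prev))
    f_converged = jnp.all(jnp.abs(f - f_prev) < atol + rtol * jnp.abs(f_prev))
    return y_converged & f_converged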

johannahaffner avatar May 07 '25 16:05 johannahaffner

Aha, @johannahaffner recalls details that I don't! Indeed this is probably achievable by adjusting the solver definitions then. (I would need to think about how.)

patrick-kidger avatar May 07 '25 19:05 patrick-kidger

After reducing a lot of the performance discrepancy using your tips above and my unmerged colouring method for tridiagonal operators (lineax PR #165), I have tried to find the source of the remaining discrepancy. I can confirm this is definitely NOT because optimistix is taking two steps. I then thought to look at whether the termination checks were a bottleneck (hence my recent equinox question) and wrote my own termination routine that always set terminate to False without checking anything (to ensure max_steps were completed). It seems that the checks do have a fixed impact on performance, but this quickly becomes negligible for larger array sizes in my case (Cauchy termination apparently uses fx from the previous step rather than calculating a new one, which helps performance). The only other discrepancy I could think of was the while loop, and it turns out this was the culprit!

Replacing the while_loop in optimistix._iterate with a jax.lax.fori_loop and running for max_steps (1 here) eliminated ALL of the performance discrepancy between optimistix.root_find and using a FunctionLinearOperator. Somehow the overhead of the while loop seems dependent on array size, with the impact being about 20% (approaching 0.1 ms) for an array size of 20k. Happy to push out further to test the persistence of this trend.

[Images: runtime comparison plots, across array sizes, with the while_loop replaced by fori_loop]

The only way I can think of to support opting into fori_loop in the public API is to have a fixed_termination kwarg that always runs for max_steps. As the cost of an initial iteration will almost always be greater than the while-loop overhead, I fully appreciate this might not be worth complicating the API for. But if one wants to always run for two steps, I don't think either of the current termination options would decide to stop early very often, so there could be a use for this. WDYT?

Note: I believe the increased cost of FunctionLinearOperator compared to JacobianLinearOperator at intermediate array sizes is likely due to the memory requirements of jax.linearize, as opposed to vmapping over jax.jvp.
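For concreteness, a sketch of the two operator constructions I mean (illustrative only; reusing loss and rhs from above):

tags = frozenset({lineax.tridiagonal_tag})

# JacobianLinearOperator: built from the function and the point; the Jacobian is formed lazily.
jac_op = lineax.JacobianLinearOperator(loss, rhs, tags=tags)

# FunctionLinearOperator: linearize once up front, then wrap the resulting linear map.
_, lin_fn = jax.linearize(lambda y: loss(y, None), rhs)
fun_op = lineax.FunctionLinearOperator(lin_fn, jax.eval_shape(lambda: rhs), tags=tags)

sol_jac = lineax.linear_solve(jac_op, rhs, solver=lineax.Tridiagonal())
sol_fun = lineax.linear_solve(fun_op, rhs, solver=lineax.Tridiagonal())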

Happy to discuss over at optimistix if you want to dig in further.

jpbrodrick89 avatar Jun 27 '25 11:06 jpbrodrick89

Interesting! So the while_loop is provided by the choice of adjoint method, and e.g. the default ImplicitAdjoint is actually just providing a lax.while_loop... which is exactly the same thing that lax.fori_loop is just a nice wrapper for.

Is the overhead perhaps coming just from checking our condition function in the loop, then?

patrick-kidger avatar Jun 27 '25 22:06 patrick-kidger

I looked at the jax.lax.fori_loop documentation and apparently, if the trip count is static, it lowers to jax.lax.scan rather than jax.lax.while_loop. I think this will be the key distinction, but I can try running some isolated tests to find out.
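An isolated test along these lines should tell us (a sketch with a toy body, independent of optimistix):

import timeit
from jax import lax

def body(y):
    return y + 0.1 * laplacian_numerator(y)  # stand-in for one solver step

@jax.jit
def via_while(y):
    cond = lambda carry: carry[0] < 1
    step = lambda carry: (carry[0] + 1, body(carry[1]))
    return lax.while_loop(cond, step, (0, y))[1]

@jax.jit
def via_fori(y):
    return lax.fori_loop(0, 1, lambda i, y: body(y), y)  # static trip count, so lowers via scan

for f in (via_while, via_fori):
    jax.block_until_ready(f(rhs))  # compile once
    t = timeit.timeit(lambda: jax.block_until_ready(f(rhs)), number=1000)
    print(f.__name__, f"{t / 1000 * 1e6:.1f} µs")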

jpbrodrick89 avatar Jun 27 '25 22:06 jpbrodrick89

Sorry, I wasn't sufficiently clear: lax.fori_loop lowers to lax.scan... which then itself lowers to lax.while_loop, just with a very simple cond_fn that checks the number of iterations. (Hence my comment on the condition function.)
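Conceptually (not the literal lowering code), the relationship is:

# With a static trip count, fori_loop behaves like a while_loop whose cond_fn
# only checks a counter; written out by hand:
def fori_via_while(n, body, init):
    def cond_fn(carry):
        i, _ = carry
        return i < n
    def body_fn(carry):
        i, x = carry
        return i + 1, body(i, x)
    return jax.lax.while_loop(cond_fn, body_fn, (0, init))[1]

x0 = jnp.ones(3)
step = lambda i, x: x + 1.0
assert jnp.allclose(jax.lax.fori_loop(0, 5, step, x0), fori_via_while(5, step, x0))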

Let me know what you find!

patrick-kidger avatar Jun 28 '25 09:06 patrick-kidger