
Improve Tridiagonal solve performance

Open jpbrodrick89 opened this issue 8 months ago • 0 comments

WIP to fix #145

Tridiagonal solve now uses jax.lax.linalg.tridiagonal_solve on GPU and, on CPU, the existing lineax implementation (written with JAX primitives) with unroll=1. I suspect my abstract_eval is not general enough, and I'm wondering whether I can get away with just importing jax.lax.linalg.standard_linalg_primitive instead (probably?). If not, I'm undecided whether it's cleaner to make abstract_eval a class method or a more verbosely named module-level function.
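A minimal sketch of what a scan-based CPU path looks like, assuming a Thomas-algorithm forward/backward sweep built on jax.lax.scan (lineax's actual code may differ). thomas_solve is a hypothetical name, the off-diagonals follow lineax's length n - 1 convention, and unroll is passed straight to lax.scan. The result is checked against jax.lax.linalg.tridiagonal_solve, which instead takes length-n diagonals with dl[0] = 0 and du[-1] = 0 plus a 2-D right-hand side, and against a dense solve.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.lax import linalg as lax_linalg

jax.config.update("jax_enable_x64", True)  # float64, as in the PR's test


def thomas_solve(lower, diag, upper, rhs, unroll=1):
    """Solve T x = rhs for tridiagonal T with the Thomas algorithm (no pivoting).

    lower, upper: (n - 1,) sub-/super-diagonals; diag, rhs: (n,).
    `unroll` is forwarded to lax.scan (hypothetical helper, not lineax's API).
    """
    dtype = jnp.result_type(lower, diag, upper, rhs)
    zero = jnp.zeros((), dtype)
    # Pad to length n so each scan step sees one row: a_0 = 0 and c_{n-1} = 0.
    a = jnp.concatenate([zero[None], lower.astype(dtype)])
    c = jnp.concatenate([upper.astype(dtype), zero[None]])

    def forward(carry, row):
        c_prev, y_prev = carry
        a_i, b_i, c_i, r_i = row
        denom = b_i - a_i * c_prev
        c_new = c_i / denom
        y_new = (r_i - a_i * y_prev) / denom
        return (c_new, y_new), (c_new, y_new)

    _, (c_star, y_star) = lax.scan(
        forward, (zero, zero), (a, diag.astype(dtype), c, rhs.astype(dtype)), unroll=unroll
    )

    def backward(x_next, row):
        c_i, y_i = row
        x_i = y_i - c_i * x_next
        return x_i, x_i

    # reverse=True walks from the last row upwards but stacks outputs in index order.
    _, x = lax.scan(backward, zero, (c_star, y_star), reverse=True, unroll=unroll)
    return x


n = 200
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
lower = jax.random.uniform(k1, (n - 1,), dtype=jnp.float64)
upper = jax.random.uniform(k2, (n - 1,), dtype=jnp.float64)
diag = 4.0 + jax.random.uniform(k3, (n,), dtype=jnp.float64)  # diagonally dominant
rhs = jax.random.uniform(k4, (n,), dtype=jnp.float64)

x_scan = jax.jit(thomas_solve, static_argnames="unroll")(lower, diag, upper, rhs)

# JAX's solver wants length-n diagonals (dl[0] = 0, du[-1] = 0) and a 2-D rhs.
dl = jnp.concatenate([jnp.zeros(1), lower])
du = jnp.concatenate([upper, jnp.zeros(1)])
x_lax = lax_linalg.tridiagonal_solve(dl, diag, du, rhs[:, None])[:, 0]

dense = jnp.diag(diag) + jnp.diag(lower, -1) + jnp.diag(upper, 1)
x_dense = jnp.linalg.solve(dense, rhs)
```

The sweep is inherently sequential (each row depends on the previous one), which is why unroll matters for the scan's performance while the arithmetic stays the same.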

Done:

  • [x] Lowering for CPU, GPU and other platforms (unbatched)
  • [x] Simple abstract eval
  • [x] A single unbatched test on CPU for float64 (operator and vector) with array size 200; the expected speedup is observed.
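One way to combine per-platform lowering with an abstract_eval is a custom primitive that registers one lowering rule per platform. A minimal sketch, not the PR's code, assuming a hypothetical primitive named tridiagonal_solve_sketch, an unbatched layout (length-m diagonals with dl[0] = du[-1] = 0 and an (m, k) right-hand side), a default lowering to jax.lax.linalg.tridiagonal_solve (used on GPU and on any platform without a specific rule) and a CPU lowering to a scan-based Thomas sweep. The Primitive import location differs between JAX versions, hence the fallback.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.interpreters import mlir
from jax.lax import linalg as lax_linalg

try:  # recent JAX
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive

jax.config.update("jax_enable_x64", True)


def _scan_solve(dl, d, du, b):
    """Thomas sweep via lax.scan (unroll=1); dl, d, du: (m,), b: (m, k)."""
    zero = jnp.zeros((), b.dtype)
    y0 = jnp.zeros(b.shape[1:], b.dtype)

    def forward(carry, row):
        c_prev, y_prev = carry
        a_i, d_i, c_i, b_i = row
        denom = d_i - a_i * c_prev
        c_new = c_i / denom
        y_new = (b_i - a_i * y_prev) / denom
        return (c_new, y_new), (c_new, y_new)

    _, (c_star, y_star) = lax.scan(forward, (zero, y0), (dl, d, du, b), unroll=1)

    def backward(x_next, row):
        c_i, y_i = row
        x_i = y_i - c_i * x_next
        return x_i, x_i

    _, x = lax.scan(backward, y0, (c_star, y_star), reverse=True, unroll=1)
    return x


def _abstract_eval(dl, d, du, b):
    # Only validates the unbatched layout: (m,) diagonals and an (m, k) rhs.
    if d.ndim != 1 or dl.shape != d.shape or du.shape != d.shape:
        raise ValueError("expected dl, d, du of identical shape (m,)")
    if b.ndim != 2 or b.shape[0] != d.shape[0]:
        raise ValueError("expected b of shape (m, k)")
    # The solution has the same shape and dtype as the right-hand side.
    return b


tridiagonal_solve_p = Primitive("tridiagonal_solve_sketch")
tridiagonal_solve_p.def_impl(_scan_solve)  # used for eager (un-jitted) calls
tridiagonal_solve_p.def_abstract_eval(_abstract_eval)

# Default rule (GPU and any platform without a specific rule): JAX's own solver.
mlir.register_lowering(
    tridiagonal_solve_p,
    mlir.lower_fun(lax_linalg.tridiagonal_solve, multiple_results=False),
)
# CPU-specific rule: the scan-based sweep.
mlir.register_lowering(
    tridiagonal_solve_p,
    mlir.lower_fun(_scan_solve, multiple_results=False),
    platform="cpu",
)


def tridiagonal_solve(dl, d, du, b):
    return tridiagonal_solve_p.bind(dl, d, du, b)


m, k = 200, 2
keys = jax.random.split(jax.random.PRNGKey(1), 4)
dl = jax.random.uniform(keys[0], (m,), dtype=jnp.float64).at[0].set(0.0)
du = jax.random.uniform(keys[1], (m,), dtype=jnp.float64).at[-1].set(0.0)
d = 4.0 + jax.random.uniform(keys[2], (m,), dtype=jnp.float64)
b = jax.random.uniform(keys[3], (m, k), dtype=jnp.float64)

x = jax.jit(tridiagonal_solve)(dl, d, du, b)  # on CPU this uses the scan rule
dense = jnp.diag(d) + jnp.diag(dl[1:], -1) + jnp.diag(du[:-1], 1)
x_dense = jnp.linalg.solve(dense, b)
```

Here the shape checks are plain module-level helpers; whether that is more or less general than reusing a shared JAX helper is exactly the open question above.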

TODO:

  • [ ] Testing on GPU
  • [ ] Implement a batching rule that uses the lineax implementation when the batch size is greater than 1.
  • [ ] Test batching
  • [ ] Test mixed dtypes (if lineax intends to support them)
  • [ ] Repeat my benchmarks
  • [ ] Get lineax benchmarks running again (see #146)?
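For the batching item, a batching rule receives each operand with its batch axis and can pick an implementation from the static batch size. A sketch under stated assumptions: a hypothetical primitive tridiagonal_solve_batched_sketch (redefined here so the block runs on its own, lowered to jax.lax.linalg.tridiagonal_solve), whose rule keeps the primitive for a batch of size 1 and otherwise vmaps a scan-based Thomas sweep that stands in for the lineax implementation.

```python
import jax
import jax.numpy as jnp
from jax import lax
from jax.interpreters import batching, mlir
from jax.lax import linalg as lax_linalg

try:  # recent JAX
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive

jax.config.update("jax_enable_x64", True)


def _scan_solve(dl, d, du, b):
    """Thomas sweep via lax.scan; dl, d, du: (m,) with dl[0] = du[-1] = 0, b: (m, k)."""
    zero, y0 = jnp.zeros((), b.dtype), jnp.zeros(b.shape[1:], b.dtype)

    def forward(carry, row):
        c_prev, y_prev = carry
        a_i, d_i, c_i, b_i = row
        denom = d_i - a_i * c_prev
        out = (c_i / denom, (b_i - a_i * y_prev) / denom)
        return out, out

    _, (c_star, y_star) = lax.scan(forward, (zero, y0), (dl, d, du, b))

    def backward(x_next, row):
        x_i = row[1] - row[0] * x_next  # x_i = y*_i - c*_i x_{i+1}
        return x_i, x_i

    return lax.scan(backward, y0, (c_star, y_star), reverse=True)[1]


tridiagonal_solve_p = Primitive("tridiagonal_solve_batched_sketch")
tridiagonal_solve_p.def_impl(_scan_solve)
tridiagonal_solve_p.def_abstract_eval(lambda dl, d, du, b: b)  # x has b's shape/dtype
mlir.register_lowering(
    tridiagonal_solve_p,
    mlir.lower_fun(lax_linalg.tridiagonal_solve, multiple_results=False),
)


def tridiagonal_solve(dl, d, du, b):
    return tridiagonal_solve_p.bind(dl, d, du, b)


def _batch_rule(batched_args, batch_dims, **params):
    # Static size of the vmapped axis, read from any operand that is batched.
    size = next(
        x.shape[bd]
        for x, bd in zip(batched_args, batch_dims)
        if bd is not batching.not_mapped
    )
    # Move batch axes to the front; broadcast operands that are not batched.
    args = [
        jnp.broadcast_to(x, (size,) + x.shape)
        if bd is batching.not_mapped
        else jnp.moveaxis(x, bd, 0)
        for x, bd in zip(batched_args, batch_dims)
    ]
    if size == 1:
        # Trivial batch: drop the axis and keep the primitive (and its lowering).
        return tridiagonal_solve_p.bind(*(x[0] for x in args))[None], 0
    # Genuine batch: vmap the scan-based implementation instead.
    return jax.vmap(_scan_solve)(*args), 0


batching.primitive_batchers[tridiagonal_solve_p] = _batch_rule


def make_system(key, batch, m=50, k=1):
    k1, k2, k3, k4 = jax.random.split(key, 4)
    dl = jax.random.uniform(k1, (batch, m), dtype=jnp.float64).at[:, 0].set(0.0)
    du = jax.random.uniform(k2, (batch, m), dtype=jnp.float64).at[:, -1].set(0.0)
    d = 4.0 + jax.random.uniform(k3, (batch, m), dtype=jnp.float64)
    b = jax.random.uniform(k4, (batch, m, k), dtype=jnp.float64)
    return dl, d, du, b


batched_solve = jax.jit(jax.vmap(tridiagonal_solve))
sys1 = make_system(jax.random.PRNGKey(0), batch=1)
sys3 = make_system(jax.random.PRNGKey(1), batch=3)
x1, x3 = batched_solve(*sys1), batched_solve(*sys3)
```

Because the batch size is a static shape, the choice is made at trace time: jax.make_jaxpr of the vmapped function contains the primitive only in the size-1 case.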

Sharing now as I'm away for the Easter weekend; I'll continue next week. Feel free to share any interim feedback, thanks.

jpbrodrick89, Apr 18 '25 02:04