lineax
Improve Tridiagonal solve performance
WIP to fix #145
The tridiagonal solve now uses `jax.lax.linalg.tridiagonal_solve` on GPU and the lineax implementation (with `unroll=1`) on CPU, dispatching between them with JAX primitives. I think my `abstract_eval` is not general enough, and I'm wondering whether I can get away with just importing `jax.lax.linalg.standard_linalg_primitive` (probably?). If not, I'm undecided whether it's cleaner to have `abstract_eval` as a class method or as a more verbosely named module-level function.
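For comparison only, here is a minimal sketch of the same CPU/GPU split written with the public `jax.lax.platform_dependent` instead of a custom primitive. This is not the approach in this PR, and it assumes a JAX version recent enough to expose `platform_dependent`. `thomas_solve` is a hypothetical scan-based Thomas algorithm standing in for lineax's implementation, and `lax_solve` is a hypothetical thin wrapper around `jax.lax.linalg.tridiagonal_solve`, which expects `dl[0] == 0`, `du[-1] == 0` and a right-hand side of shape `(..., m, n)`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.lax import linalg as lax_linalg

jax.config.update("jax_enable_x64", True)


def thomas_solve(dl, d, du, b):
    """HYPOTHETICAL stand-in for lineax's CPU path: Thomas algorithm via lax.scan.

    Shapes: dl, d, du are (m,); b is (m, n). Assumes dl[0] == 0 and du[-1] == 0.
    """

    def forward(carry, row):
        c_prev, y_prev = carry
        dl_i, d_i, du_i, b_i = row
        denom = d_i - dl_i * c_prev
        c_i = du_i / denom
        y_i = (b_i - dl_i * y_prev) / denom
        return (c_i, y_i), (c_i, y_i)

    # With dl[0] == 0 the first step reduces to c_0 = du_0 / d_0, y_0 = b_0 / d_0.
    init = (jnp.zeros((), d.dtype), jnp.zeros(b.shape[1:], b.dtype))
    _, (c, y) = lax.scan(forward, init, (dl, d, du, b), unroll=1)

    def backward(x_next, row):
        c_i, y_i = row
        x_i = y_i - c_i * x_next
        return x_i, x_i

    # du[-1] == 0 makes c[-1] == 0, so the last unknown is simply y[-1].
    _, x = lax.scan(backward, jnp.zeros(b.shape[1:], b.dtype), (c, y),
                    reverse=True, unroll=1)
    return x


def lax_solve(dl, d, du, b):
    # JAX's built-in solver (cuSPARSE-backed on GPU).
    return lax_linalg.tridiagonal_solve(dl, d, du, b)


@jax.jit
def tridiagonal_solve(dl, d, du, b):
    # Enforce the convention that dl[0] and du[-1] are zero.
    dl = dl.at[0].set(0)
    du = du.at[-1].set(0)
    # Both branches are staged out; the one matching the lowering platform is
    # used: CPU gets the scan-based solver, every other platform the built-in one.
    return lax.platform_dependent(dl, d, du, b, cpu=thomas_solve, default=lax_solve)


# Usage: a diagonally dominant 200 x 200 system with one right-hand side.
rng = np.random.default_rng(0)
m = 200
dl = jnp.asarray(rng.uniform(-1.0, 1.0, m))
du = jnp.asarray(rng.uniform(-1.0, 1.0, m))
d = jnp.asarray(4.0 + rng.uniform(0.0, 1.0, m))
b = jnp.asarray(rng.uniform(-1.0, 1.0, (m, 1)))
x = tridiagonal_solve(dl, d, du, b)
```

Because `platform_dependent` picks the branch at lowering time, this route needs no custom `abstract_eval`; the trade-off is that it offers no hook for a dedicated batching rule.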
Done:
- [x] Lowering to CPU, GPU and other platforms (unbatched)
- [x] Simple abstract eval
- [x] A single test for float64 (operator and vector) with array size 200, unbatched, on CPU; the expected speedup is observed.
TODO:
- [ ] Testing on GPU
- [ ] Implement a batching rule to use the lineax implementation when the batch size is greater than 1
- [ ] Test batching
- [ ] Test mixed dtypes (if lineax intends to support that)
- [ ] Repeat my benchmarks
- [ ] Get lineax benchmarks running again (see #146)?
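For the batching item, one possible sketch (an assumption, not what this PR implements) uses the public `jax.custom_batching.custom_vmap`: its rule receives the static `axis_size`, so batches larger than 1 can be routed to a different implementation. `fallback_solve` is a hypothetical placeholder for lineax's scan-based solver; it uses a dense solve purely to keep the sketch short.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap
from jax.lax import linalg as lax_linalg

jax.config.update("jax_enable_x64", True)


def fallback_solve(dl, d, du, b):
    # HYPOTHETICAL placeholder for lineax's scan-based solver: assemble the
    # dense tridiagonal matrix and solve it directly (fine for a sketch only).
    a = jnp.diag(d) + jnp.diag(dl[1:], k=-1) + jnp.diag(du[:-1], k=1)
    return jnp.linalg.solve(a, b)


@custom_vmap
def tridiagonal_solve(dl, d, du, b):
    # Unbatched call: use JAX's built-in tridiagonal solver.
    return lax_linalg.tridiagonal_solve(dl, d, du, b)


@tridiagonal_solve.def_vmap
def _tridiagonal_solve_vmap(axis_size, in_batched, dl, d, du, b):
    # Give every operand a leading batch axis of length axis_size.
    args = [
        x if batched else jnp.broadcast_to(x, (axis_size,) + x.shape)
        for x, batched in zip((dl, d, du, b), in_batched)
    ]
    if axis_size > 1:
        # Batch size > 1: vectorise the fallback implementation.
        out = jax.vmap(fallback_solve)(*args)
    else:
        # Batch size 1: drop the batch axis, solve once, then restore it.
        out = lax_linalg.tridiagonal_solve(*(x[0] for x in args))[None]
    return out, True  # the output is batched along axis 0


# Usage: three systems that differ only in their main diagonal.
m = 5
dl = jnp.concatenate([jnp.zeros(1), jnp.ones(m - 1)])  # dl[0] must be 0
du = jnp.concatenate([jnp.ones(m - 1), jnp.zeros(1)])  # du[-1] must be 0
ds = 4.0 + jnp.arange(3 * m, dtype=jnp.float64).reshape(3, m)
b = jnp.ones((m, 1))
xs = jax.vmap(tridiagonal_solve, in_axes=(None, 0, None, None))(dl, ds, du, b)
```

A custom primitive would get the equivalent hook from a batching rule registered for it; `custom_vmap` simply keeps this sketch on public API.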
Sharing now as I'm away for the Easter weekend; I'll continue next week. Feel free to share any interim feedback, thanks.