Why is jax.lax.custom_root not used?

Open • sbodenstein opened this issue 3 months ago • 1 comment

Why does optimistix use a custom jvp rather than the jax.lax.custom_root primitive? https://docs.jax.dev/en/latest/_autosummary/jax.lax.custom_root.html

sbodenstein, Oct 06 '25 14:10

So if you're asking in terms of JAX abstractions: jax.lax.custom_root is itself backed by a custom_jvp, so the two approaches are very similar in that regard:

https://github.com/jax-ml/jax/blob/349cbc24bfe716c91e8e6af6b12a9100766c9783/jax/_src/lax/control_flow/solves.py#L131-L132
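For reference, a minimal sketch of jax.lax.custom_root on a scalar problem. It finds sqrt(a) as the root of x**2 - a; `solve` only computes the primal root, and `tangent_solve` solves the linearised system that custom_root's implicit-differentiation rule needs, so the solver's iterations are never differentiated. The bisection solver, its [0, 100] bracket, and the helper names (`bisection_solve`, `scalar_tangent_solve`, `sqrt_via_custom_root`) are assumptions for illustration only.

```python
import jax
import jax.numpy as jnp


def bisection_solve(f, x0):
    # Hypothetical derivative-free primal solver: bisection on [0, 100].
    # Assumes f is increasing and changes sign inside that bracket.
    lo = jnp.zeros_like(x0)
    hi = jnp.full_like(x0, 100.0)

    def body(_, bracket):
        lo, hi = bracket
        mid = 0.5 * (lo + hi)
        below_root = f(mid) < 0  # the root lies to the right of mid
        return jnp.where(below_root, mid, lo), jnp.where(below_root, hi, mid)

    lo, hi = jax.lax.fori_loop(0, 60, body, (lo, hi))
    return 0.5 * (lo + hi)


def scalar_tangent_solve(g, y):
    # g is the linearisation of f at the root; for a scalar, g(t) = g(1) * t,
    # so solving g(t) = y is a single division.
    return y / g(1.0)


def sqrt_via_custom_root(a):
    # The positive root of x**2 - a is sqrt(a); custom_root supplies the
    # derivative with respect to the closed-over `a` implicitly.
    f = lambda x: x ** 2 - a
    return jax.lax.custom_root(
        f, jnp.float32(1.0), bisection_solve, scalar_tangent_solve
    )


value, grad = jax.value_and_grad(sqrt_via_custom_root)(4.0)
print(value, grad)  # approximately 2.0 and 0.25, i.e. sqrt(4) and 1 / (2 * sqrt(4))
```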

If you're asking about practical reasons, then:
(a) it's fairly straightforward code to write, so owning it ourselves gives us more control to fix and tweak things;
(b) splitting apart dynamic/static values, piping through Lineax, etc., to fit into jax.lax.custom_root would result in less clean code, and about as many lines of code as the implementation we have here.
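To make point (a) concrete, here is a minimal scalar sketch of the hand-rolled alternative: a root find wrapped in jax.custom_jvp whose JVP rule applies the implicit function theorem directly. This is not Optimistix's implementation (which handles PyTrees and hands the linear solve to Lineax); `implicit_root`, `square_minus` and `sqrt_implicit` are hypothetical names, and the fixed 20 Newton steps and scalar-only JVP are simplifying assumptions.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def implicit_root(fn, y0, theta):
    # Primal: 20 Newton steps on fn(y, theta) = 0, starting from y0.
    # These iterations are never differentiated; the JVP rule replaces them.
    def newton_step(_, y):
        return y - fn(y, theta) / jax.grad(fn, argnums=0)(y, theta)

    return jax.lax.fori_loop(0, 20, newton_step, y0)


@implicit_root.defjvp
def implicit_root_jvp(fn, primals, tangents):
    y0, theta = primals
    _, theta_dot = tangents  # the root does not depend on the initial guess
    y = implicit_root(fn, y0, theta)
    # Implicit function theorem for a scalar root:
    #   dfn/dy * y_dot + dfn/dtheta * theta_dot = 0
    dfdy = jax.grad(fn, argnums=0)(y, theta)
    _, dfdtheta_dot = jax.jvp(lambda t: fn(y, t), (theta,), (theta_dot,))
    return y, -dfdtheta_dot / dfdy


def square_minus(y, theta):
    return y ** 2 - theta


def sqrt_implicit(a):
    return implicit_root(square_minus, jnp.float32(1.0), a)


value, grad = jax.value_and_grad(sqrt_implicit)(4.0)
print(value, grad)  # approximately 2.0 and 0.25
```

Writing the rule by hand like this is what leaves room to swap in a different linear solve for the tangent system, which is the kind of control point (a) refers to.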

(The fact that we have a custom linear solve primitive, rather than using jax.lax.custom_linear_solve, is the more serious version of this question 😁)
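For comparison, a minimal use of jax.lax.custom_linear_solve on a small dense system. `dense_solve` is a hypothetical helper, and using jnp.linalg.solve as the inner solver is an assumption for illustration. JAX derives the gradient from the relation A x = b, using `matvec` for the operator and `solve`/`transpose_solve` for the linear systems, without differentiating the solver's internals. This sketch says nothing about how Lineax's own primitive is built.

```python
import jax
import jax.numpy as jnp


def dense_solve(a, b):
    # matvec defines the linear operator A; derivatives with respect to `a`
    # flow through this function, not through the solver.
    matvec = lambda x: a @ x
    # Any primal solver works; here a direct dense solve closing over `a`.
    solve = lambda _matvec, rhs: jnp.linalg.solve(a, rhs)
    # Used in reverse mode: solves the transposed system A^T x = rhs.
    transpose_solve = lambda _vecmat, rhs: jnp.linalg.solve(a.T, rhs)
    return jax.lax.custom_linear_solve(matvec, b, solve, transpose_solve)


a = jnp.array([[2.0, 1.0], [0.0, 4.0]])
b = jnp.array([3.0, 4.0])
x = dense_solve(a, b)  # solves A x = b for x
grad_a = jax.grad(lambda m: dense_solve(m, b).sum())(a)
print(x, grad_a)
```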

patrick-kidger, Oct 06 '25 15:10