Why is jax.lax.custom_root not used?
Why does Optimistix use its own custom JVP rather than the jax.lax.custom_root primitive? https://docs.jax.dev/en/latest/_autosummary/jax.lax.custom_root.html
If you're asking in terms of JAX abstractions, then jax.lax.custom_root is itself implemented with a custom_jvp, so the two approaches are very similar in that regard:
https://github.com/jax-ml/jax/blob/349cbc24bfe716c91e8e6af6b12a9100766c9783/jax/_src/lax/control_flow/solves.py#L131-L132
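For comparison, a minimal use of jax.lax.custom_root might look like the sketch below. The names `bisection_solve`, `scalar_tangent_solve` and `sqrt_via_root` are hypothetical, and the solver is a fixed-iteration bisection on an assumed bracket [0, 100]. The bisection loop is never differentiated through: the derivative of the root with respect to `a` comes from implicit differentiation using the `tangent_solve` argument.

```python
import jax
import jax.numpy as jnp


def bisection_solve(f, x0, low=0.0, high=100.0, iters=60):
    # Hypothetical solver: bisection on [low, high], assuming f is increasing
    # there and changes sign. Bisection does not use the initial guess.
    del x0

    def body(_, state):
        lo, hi = state
        mid = 0.5 * (lo + hi)
        go_left = f(mid) > 0  # root lies in [lo, mid]
        return jnp.where(go_left, lo, mid), jnp.where(go_left, mid, hi)

    lo, hi = jax.lax.fori_loop(0, iters, body, (jnp.float32(low), jnp.float32(high)))
    return 0.5 * (lo + hi)


def scalar_tangent_solve(g, y):
    # g is f linearised at the root; for a scalar, g(x) = g(1.0) * x.
    return y / g(1.0)


def sqrt_via_root(a):
    f = lambda x: x**2 - a  # root at x = sqrt(a)
    return jax.lax.custom_root(f, jnp.float32(1.0), bisection_solve, scalar_tangent_solve)


print(sqrt_via_root(4.0))            # approximately 2.0
print(jax.grad(sqrt_via_root)(4.0))  # approximately 0.25 = 1 / (2 * sqrt(4))
```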
If you're asking about practical reasons, then:
(a) it's pretty straightforward code to write, so having it ourselves gives us greater control to fix and tweak things;
(b) the splitting apart of dynamic/static values, piping through Lineax, etc., that would be needed to go through jax.lax.custom_root would result in less clean code, and in about as many lines of code as the implementation we have here.
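To give a sense of how little code the hand-rolled route needs, here is a sketch of the same idea written directly with jax.custom_jvp. This is illustrative only and not Optimistix's actual implementation: `implicit_root` and `fn` are hypothetical names, the solver is again a fixed-iteration bisection on an assumed bracket [0, 100], and the tangent uses the scalar implicit function theorem, dx = -(dfn/dx)^{-1} (dfn/dtheta · theta_dot), with a plain division standing in for a Lineax linear solve.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.custom_jvp, nondiff_argnums=(0,))
def implicit_root(fn, theta):
    # Solve fn(x, theta) == 0 for scalar x by bisection on the assumed bracket
    # [0, 100]; fn is a static (non-differentiable) argument via nondiff_argnums.
    def body(_, state):
        lo, hi = state
        mid = 0.5 * (lo + hi)
        go_left = fn(mid, theta) > 0  # root lies in [lo, mid]
        return jnp.where(go_left, lo, mid), jnp.where(go_left, mid, hi)

    lo, hi = jax.lax.fori_loop(0, 60, body, (jnp.float32(0.0), jnp.float32(100.0)))
    return 0.5 * (lo + hi)


@implicit_root.defjvp
def implicit_root_jvp(fn, primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    x = implicit_root(fn, theta)  # primal solve; not differentiated through
    # Implicit function theorem for scalar x:
    #   dx = -(dfn/dx)^{-1} * (dfn/dtheta . theta_dot)
    _, rhs = jax.jvp(lambda t: fn(x, t), (theta,), (theta_dot,))
    dfdx = jax.grad(fn, argnums=0)(x, theta)
    return x, -rhs / dfdx


def fn(x, theta):
    return x**2 - theta  # root at x = sqrt(theta)


print(implicit_root(fn, 4.0))                         # approximately 2.0
print(jax.grad(lambda t: implicit_root(fn, t))(4.0))  # approximately 0.25
```

In this sketch any parameter you want derivatives with respect to is passed explicitly as `theta`, while `fn` itself is treated as static.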
(The fact that we have our own custom linear solve primitive, rather than using jax.lax.custom_linear_solve, is the more serious version of this question 😁)