Utilise pivoted QR functionality added in JAX v0.5.1
Since JAX v0.5.1 supports pivoted QR factorisations on both CPU and GPU backends (the latter via MAGMA), the following limitation outlined in the docstring of lineax's `QR` linear solver can now be overcome:
```python
class QR(AbstractLinearSolver, strict=True):
    """QR solver for linear systems.

    This solver can handle non-square operators.

    This is usually the preferred solver when dealing with non-square operators.

    !!! info

        Note that whilst this does handle non-square operators, it still can only
        handle full-rank operators.

        This is because JAX does not currently support a rank-revealing/pivoted QR
        decomposition, see [issue #12897](https://github.com/google/jax/issues/12897).

        For such use cases, switch to [`lineax.SVD`][] instead.
    """
```
However, I'm not sure of the best way to implement this, considering that `QR` currently relies on the fact that it is not rank-revealing to make its JVP more efficient, as noted in `QR.allow_dependent_columns(...)`.
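To spell out why that trick depends on full rank: assuming the JVP follows the standard Golub-Pereyra derivative of the pseudo-inverse, one of its three terms is multiplied by `(I - A⁺A)`, which is exactly zero when the columns of `A` are linearly independent. A non-pivoted QR can only handle full-rank operators, so for `columns <= rows` that term can be skipped. Once rank-deficient tall operators are accepted, it no longer vanishes. The sketch below checks this numerically with `jnp.linalg.pinv`; `pinv_derivative_terms` is a hypothetical helper, not lineax's JVP code.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def pinv_derivative_terms(a, da):
    """Split the derivative of pinv(a) along `da` into its three terms (real `a`).

    Golub-Pereyra formula, valid while rank(a) is locally constant:
      d(A+) = -A+ dA A+ + A+ A+^T dA^T (I - A A+) + (I - A+ A) dA^T A+^T A+
    """
    m, n = a.shape
    a_pinv = jnp.linalg.pinv(a)
    base = -a_pinv @ da @ a_pinv
    # Zero when the rows of `a` are linearly independent (a @ a_pinv == I).
    row_term = a_pinv @ a_pinv.T @ da.T @ (jnp.eye(m) - a @ a_pinv)
    # Zero when the columns of `a` are linearly independent (a_pinv @ a == I).
    col_term = (jnp.eye(n) - a_pinv @ a) @ da.T @ a_pinv.T @ a_pinv
    return base, row_term, col_term


key_a, key_da = jax.random.split(jax.random.PRNGKey(0))
da = jax.random.normal(key_da, (4, 3))
full_rank = jax.random.normal(key_a, (4, 3))  # tall, generically full column rank
rank_deficient = jnp.array([[1.0, 0.0, 1.0],
                            [0.0, 1.0, 1.0],
                            [1.0, 1.0, 2.0],
                            [2.0, 0.0, 2.0]])  # column 2 = column 0 + column 1

for name, mat in [("full rank", full_rank), ("rank deficient", rank_deficient)]:
    _, _, col_term = pinv_derivative_terms(mat, da)
    print(name, float(jnp.max(jnp.abs(col_term))))
```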
Hijacking this thread to note that https://github.com/jax-ml/jax/commit/e03fe3a06d1567bf738f859894a0fd98e6be4d6d also adds support for a QR-based SVD algorithm, which should improve numerical stability.
Ah, nice! I'd be happy to take a PR adding support for this.
For the JVP, perhaps we should just disable that optimisation? It's a fairly subtle implementation detail. (Or we could just add a flag.)
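A minimal sketch of the flag option, assuming it would only need to change what the solver reports about dependent rows and columns. `QRSolverSketch` and its `pivoting` field are hypothetical; the real lineax methods receive the operator itself, whereas this stand-in takes `rows` and `columns` directly to stay self-contained. The non-pivoted branches paraphrase the full-rank reasoning described above rather than copying lineax's source.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class QRSolverSketch:
    """Hypothetical stand-in for a QR solver with an illustrative `pivoting` flag."""

    pivoting: bool = False

    def allow_dependent_columns(self, rows: int, columns: int) -> bool:
        if self.pivoting:
            # Rank-revealing QR may be applied to rank-deficient operators,
            # so independent columns can no longer be assumed.
            return True
        # Non-pivoted QR requires full rank: with columns <= rows the columns
        # are independent, which permits the cheaper JVP.
        return columns > rows

    def allow_dependent_rows(self, rows: int, columns: int) -> bool:
        if self.pivoting:
            return True
        # Full rank with rows <= columns means the rows are independent.
        return rows > columns


print(QRSolverSketch().allow_dependent_columns(5, 3))
print(QRSolverSketch(pivoting=True).allow_dependent_columns(5, 3))
```

Simply disabling the optimisation corresponds to the `pivoting=True` branch applied unconditionally: the JVP always computes the extra term, at some cost for full-rank tall operators.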