lineax

Utilise pivoted QR functionality added in JAX v0.5.1

Open · tttc3 opened this issue 10 months ago · 2 comments

Since JAX v0.5.1 supports pivoted QR factorisations on both CPU and GPU backends (the latter via MAGMA), the following limitation outlined in the `QR` linear solver docstring can now be overcome:

```python
class QR(AbstractLinearSolver, strict=True):
    """QR solver for linear systems.

    This solver can handle non-square operators.

    This is usually the preferred solver when dealing with non-square operators.

    !!! info

        Note that whilst this does handle non-square operators, it still can only
        handle full-rank operators.

        This is because JAX does not currently support a rank-revealing/pivoted QR
        decomposition, see [issue #12897](https://github.com/google/jax/issues/12897).

        For such use cases, switch to [`lineax.SVD`][] instead.
    """
```
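To illustrate what a rank-revealing factorisation buys, here is a minimal NumPy sketch of Householder QR with column pivoting (the classical Businger-Golub scheme). It stands in for the JAX routine and does not reproduce its implementation. On an operator with a dependent column, unpivoted QR leaves a near-zero pivot in the middle of `R`, so back-substitution divides by it; pivoting pushes the near-zero pivots to the end, where they can be truncated to give a basic least-squares solution. `pivoted_qr`, `numerical_rank` and `basic_solve` are hypothetical helpers, and the relative tolerance `rtol=1e-10` is an assumption.

```python
import numpy as np


def pivoted_qr(A):
    """Householder QR with column pivoting, so that A[:, p] == Q @ R."""
    R = np.array(A, dtype=float)
    m, n = R.shape
    Q = np.eye(m)
    p = np.arange(n)
    for k in range(min(m, n)):
        # Bring the remaining column with the largest trailing norm to position k.
        # (Recomputed every step for clarity; real codes downdate these norms.)
        j = k + int(np.argmax(np.linalg.norm(R[k:, k:], axis=0)))
        R[:, [k, j]] = R[:, [j, k]]
        p[[k, j]] = p[[j, k]]
        # Householder reflector H = I - 2 v v^T that zeroes R[k+1:, k].
        x = R[k:, k]
        norm_x = np.linalg.norm(x)
        if norm_x == 0.0:
            continue
        v = x.copy()
        v[0] += np.copysign(norm_x, x[0])
        v /= np.linalg.norm(v)
        R[k:, :] -= 2.0 * np.outer(v, v @ R[k:, :])
        Q[:, k:] -= 2.0 * np.outer(Q[:, k:] @ v, v)
        R[k + 1:, k] = 0.0  # store exact zeros below the diagonal
    return Q, R, p


def numerical_rank(R, rtol=1e-10):
    """Count |R[i, i]| above rtol times the largest diagonal magnitude."""
    d = np.abs(np.diag(R))
    return int(np.sum(d > rtol * d.max()))


def basic_solve(A, b, rtol=1e-10):
    """Least-squares 'basic' solution: truncate R to its numerical rank r."""
    Q, R, p = pivoted_qr(A)
    r = numerical_rank(R, rtol)
    z = np.linalg.solve(R[:r, :r], Q[:, :r].T @ b)
    x = np.zeros(R.shape[1])
    x[p[:r]] = z  # undo the column permutation; trailing components stay zero
    return x


rng = np.random.default_rng(0)
a0, a1 = rng.standard_normal((2, 6))
# Rank-2 operator whose column 1 (2 * a0) depends on column 0.
A = np.column_stack([a0, 2.0 * a0, a1, a0 - a1])
b = rng.standard_normal(6)

_, R_plain = np.linalg.qr(A)  # unpivoted: near-zero pivot at position 1
Q, R, p = pivoted_qr(A)       # pivoted: near-zero pivots pushed to the end
x = basic_solve(A, b)
print(np.abs(np.diag(R_plain)))
print(np.abs(np.diag(R)), numerical_rank(R))
```

The basic solution zeroes the trailing (permuted) components, so in general it differs from the minimum-norm solution an SVD-based solve produces, although both minimise the residual.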

However, I'm not sure of the best way to implement this, given that `QR` currently exploits the fact that it is not rank-revealing to make its JVP more efficient, as noted in `QR.allow_dependent_columns(...)`.
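For context, the efficiency trick rests on the derivative of the Moore-Penrose pseudoinverse. For an operator $A$ whose rank is locally constant, a standard result (Golub and Pereyra, 1973) gives

$$
\mathrm{d}(A^{+}) = -A^{+}\,\mathrm{d}A\,A^{+}
  + A^{+}A^{+\top}\,\mathrm{d}A^{\top}\,(I - AA^{+})
  + (I - A^{+}A)\,\mathrm{d}A^{\top}\,A^{+\top}A^{+}.
$$

If $A$ has linearly independent rows then $AA^{+} = I$ and the second term vanishes; if it has linearly independent columns then $A^{+}A = I$ and the third term vanishes. A non-pivoted QR only succeeds on full-rank operators, so (as we read `allow_dependent_columns`/`allow_dependent_rows`) at most one of the two extra terms is ever needed, whereas a rank-revealing QR could meet operators with both dependent rows and dependent columns and would then need both.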

tttc3 · Mar 14 '25 22:03

Hijacking this to note that https://github.com/jax-ml/jax/commit/e03fe3a06d1567bf738f859894a0fd98e6be4d6d also adds support for a QR implementation of SVD, which should improve numerical stability.

johannahaffner · Mar 14 '25 22:03

Ah, nice! I'd be happy to take a PR adding support for this.

For the JVP, perhaps we should just disable that? It's a fairly subtle implementation detail. (Or we could just add a flag.)
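Disabling the trick amounts to always keeping both correction terms in the pseudoinverse derivative, $-A^{+}\,\mathrm{d}A\,A^{+}$ plus one term gated by dependent rows and one gated by dependent columns. The NumPy sketch below checks this numerically: on a rank-2 operator whose rank is held fixed along the perturbation, the full three-term derivative matches central finite differences, while dropping either correction term does not. `pinv_truncated`, `pinv_jvp` and `dependency_flags` are hypothetical helpers, not lineax API; `dependency_flags` is one guess at what the suggested flag might control, assuming the shape-based reasoning described in `QR.allow_dependent_columns`.

```python
import numpy as np


def pinv_truncated(A, rtol=1e-10):
    """Pseudoinverse via SVD, dropping singular values below rtol * s_max."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    r = int(np.sum(s > rtol * s[0]))
    return Vt[:r].T @ np.diag(1.0 / s[:r]) @ U[:, :r].T


def pinv_jvp(A, dA, dependent_rows=True, dependent_columns=True):
    """Derivative of pinv(A) along dA, assuming the rank of A stays constant."""
    m, n = A.shape
    Ap = pinv_truncated(A)
    out = -Ap @ dA @ Ap
    if dependent_rows:  # vanishes when A @ Ap == I (independent rows)
        out = out + Ap @ Ap.T @ dA.T @ (np.eye(m) - A @ Ap)
    if dependent_columns:  # vanishes when Ap @ A == I (independent columns)
        out = out + (np.eye(n) - Ap @ A) @ dA.T @ Ap.T @ Ap
    return out


def dependency_flags(rows, columns, rank_revealing):
    """Hypothetical: (dependent_rows, dependent_columns) a QR solver must allow."""
    if rank_revealing:
        # A pivoted QR may see rank-deficient operators, so keep both terms.
        return True, True
    # Full-rank only: tall operators have dependent rows, wide ones dependent columns.
    return rows > columns, columns > rows


# Rank-2 operator A(t) = (B + t dB)(C + t dC); its rank stays 2 near t = 0.
rng = np.random.default_rng(0)
B, dB = rng.standard_normal((2, 5, 2))
C, dC = rng.standard_normal((2, 2, 4))


def A_of(t):
    return (B + t * dB) @ (C + t * dC)


dA = dB @ C + B @ dC  # d/dt A(t) at t = 0
h = 1e-6
fd = (pinv_truncated(A_of(h)) - pinv_truncated(A_of(-h))) / (2 * h)
full = pinv_jvp(A_of(0.0), dA)
no_row_term = pinv_jvp(A_of(0.0), dA, dependent_rows=False)
no_col_term = pinv_jvp(A_of(0.0), dA, dependent_columns=False)
```

A flag along these lines would let the non-pivoted path keep its cheaper JVP, while a pivoted path pays for both terms.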

patrick-kidger · Mar 15 '25 14:03