
Replace allow_dependent_rows and allow_dependent_columns with assume_full_rank

Open adconner opened this issue 7 months ago • 8 comments

Together, the two functions allow_dependent_{rows,columns} answered whether the solver assumes its input is full rank, for the purposes of the JVP. Allowing them to be implemented separately created some issues:

  1. Invalid states were representable. For example: what does it mean for dependent columns to be allowed for square matrices when dependent rows are not? What does it mean for dependent rows to be disallowed for matrices with more rows than columns, whose rows are necessarily dependent?
  2. Since the functions take the operator as input, a custom solver could in principle decide its answer based on the operator's dynamic value, rather than only on statically known compile-time information about it, as all the lineax-defined solvers do. This would break JAX tracing and jit compilation.

Both issues are addressed by asking the solver to report only whether it assumes its input is numerically full rank. If this assumption is violated, the solver's behavior is undefined: it is allowed to error, to produce NaNs, or to produce invalid values.
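As a rough sketch of the proposed interface (illustrative only: whether the real method takes any arguments isn't shown in this thread), assuming lineax's AbstractLinearSolver base class:

from lineax import AbstractLinearSolver

class MyFullRankSolver(AbstractLinearSolver):
    ...

    # Replaced: allow_dependent_rows(self, operator) and
    # allow_dependent_columns(self, operator), which could be implemented
    # inconsistently and could inspect the operator's dynamic values.

    # Proposed: a single assumption, returned as a plain Python bool so the
    # answer is always compile-time static.
    def assume_full_rank(self) -> bool:
        return True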

adconner avatar Jun 09 '25 20:06 adconner

If you wanted to go further than this PR and did not mind a breaking change, you could make full_rank an operator tag. Then solvers requiring full-rank operators would check it just like they check any other tag, and the JVP would check the operator tag when choosing which JVP terms to emit. This has benefits:

  1. More general solvers also benefit from a faster JVP when called with operators statically known (or assumed) to be full rank.
  2. AutoLinearSolver can then be fully automatic, without requiring a well_posed argument.
  3. It makes all static assumptions about operators needed by solvers explicit in operator tags.

Users would have to update their code to tag most operators as full rank.
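For reference, this is how lineax's existing tag mechanism works; a full_rank tag would slot in alongside existing tags such as positive_semidefinite_tag (note that full_rank_tag below is hypothetical, not current lineax API):

import jax.numpy as jnp
import lineax as lx

# Existing mechanism: tags are static metadata attached to an operator, and
# solvers check them at trace time. Cholesky requires the (negative or)
# positive semidefinite tag.
matrix = jnp.array([[2.0, 0.5], [0.5, 1.0]])
operator = lx.MatrixLinearOperator(matrix, tags=lx.positive_semidefinite_tag)
solution = lx.linear_solve(operator, jnp.array([1.0, 2.0]), solver=lx.Cholesky())

# Under the proposal, users would additionally write something like
#     lx.MatrixLinearOperator(matrix, tags=lx.full_rank_tag)
# and the JVP rule would consult the tag.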

adconner avatar Jun 10 '25 14:06 adconner

Hey there! So you're actually touching on a design point from way back in the early days of Lineax. We in fact originally had a maybe_singular_tag on the operator rather than a description of what could be handled by the solver.

We eventually nixed this because it was too much of a footgun.

  • If Lineax defaulted to assuming well-posedness and you left off the maybe_singular_tag, then you'd silently get incorrect JVPs.
  • If Lineax defaulted to assuming singularity (as you seem to be suggesting here) and you left off the full_rank tag, then you'd get silently expensive JVPs. Or conversely, if adding that tag became the accepted default, then forgetting to remove it puts you back in the land of silently incorrect JVPs.

Either way this was clearly going to go wrong for a lot of users a substantial fraction of the time!

Making this a property of the solver instead reflects real-world usage. You're probably only using SVD if you have a singular problem, for example, and you're anyway happy to accept the extra computational cost in the JVP rule.


As for the adjustment you're originally making in this PR: being able to distinguish these two cases is useful for the sake of faster JVPs, in particular for QR solves.

On (1), it's true that illegal states are representable. We could probably add some assert statements to _linear_solve_jvp to catch the cases you highlight, which are all determinable from compile-time information. On (2), indeed it has to be static; this is already captured in its type annotation -> bool, which is not a JAX array. It is an error for a solver to return anything else -- as meaningless as returning a string or object() etc.
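A minimal sketch of those asserts (hypothetical: _linear_solve_jvp is lineax-internal, and the variable names here are illustrative), using the fact that everything involved is static:

# All quantities here are static Python ints/bools, so these asserts fire
# at trace time, not at runtime.
rows, cols = operator.out_size(), operator.in_size()
dep_rows = solver.allow_dependent_rows(operator)
dep_cols = solver.allow_dependent_columns(operator)
assert type(dep_rows) is bool and type(dep_cols) is bool
if rows == cols:
    # A rank-deficient square matrix has both dependent rows and dependent
    # columns, so allowing one but not the other is meaningless.
    assert dep_rows == dep_cols
if rows > cols:
    # A tall matrix always has dependent rows.
    assert dep_rows
if cols > rows:
    # A wide matrix always has dependent columns.
    assert dep_cols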

WDYT?

patrick-kidger avatar Jun 11 '25 18:06 patrick-kidger

If Lineax defaulted to assuming singularity (as you seem to be suggesting here) and you left off the full_rank tag, then you'd get silently expensive JVPs. Or conversely, if adding that tag became the accepted default, then forgetting to remove it puts you back in the land of silently incorrect JVPs.

Making this a property of the solver instead reflects real-world usage. You're probably only using SVD if you have a singular problem, for example, and you're anyway happy to accept the extra computational cost in the JVP rule.

The user at some point needs to indicate whether they assume their operators are full rank. Currently they do this via their solver selection, if choosing manually, or via the well_posed argument to AutoLinearSolver. The problem (?) of users having to correctly identify their assumptions is already present, and it exists irrespective of whether we express the assumptions in the solver or in the operator.

Keep in mind that if the user forgets the full_rank tag and uses a full-rank-only solver like QR, they will get an error, just as if they had forgotten the positive semidefinite tag for the Cholesky solver. It is true that AutoLinearSolver would select SVD for untagged operators, but this is the only source of silent inefficiency (and this too is analogous to the psd tag: the user already gets silent inefficiencies from the automatic solver selection if they do not tag a known psd operator).

What's more, there is an additional benefit to the operator (rather than the solver) knowing whether it is full rank, where there isn't one for another static assertion like psd: more efficient JVPs can be emitted for that operator even when using solvers that support more general operators.

As for the adjustment you're originally making in this PR: being able to distinguish these two cases is useful for the sake of faster JVPs, in particular for QR solves.

The JVPs of this approach are identical to those of the previous approach, including for QR. Notice that in the calculation of the JVP, we now recreate the old information as needed. Any values of allow_dependent_{rows,columns} which are not functions of the operator being {tall, square, wide} and of whether it is assumed to be full rank are invalid.
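For concreteness, a minimal sketch of that reconstruction (a hypothetical helper, not the PR's actual code), which agrees with the table given later in this thread:

def dependent_flags(rows: int, cols: int, assume_full_rank: bool):
    # Recreate (allow_dependent_rows, allow_dependent_columns) from the
    # operator's static shape plus the single full-rank assumption.
    if not assume_full_rank:
        return True, True
    # Full rank: a tall matrix still has dependent rows and a wide matrix
    # dependent columns; a square full-rank matrix has neither.
    return rows > cols, cols > rows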

adconner avatar Jun 11 '25 19:06 adconner

It is true that AutoLinearSolver would select SVD for untagged operators, but this is the only source of silent inefficiency (and this too is analogous to the psd tag: the user already gets silent inefficiencies from the automatic solver selection if they do not tag a known psd operator).

Actually, there is one further problem here (and one that is a motivating reason for us not to do this): if a user has what they believe to be a full-rank operator, but mistakenly provides a rank-deficient operator, then SVD will silently compute a pseudoinverse (least-squares/least-norm) solve, rather than erroring out as would be desirable.

Any values of allow_dependent_{rows,columns} which are not functions of the operator being {tall, square, wide} and of whether it is assumed to be full rank are invalid.

On this part, I think I like these changes. This was always a subtle point I had to think a lot about.

I think my main comment is that, in general, this could still be a function of any structure/sparsity in the operator? I'd need to noodle on this to be sure I'm getting it right, but I could believe that certain solvers might exhibit or not exhibit this behaviour depending on the operator. (The trivial example is of course a user-defined assume_full_rank tag that the solver checks and respects! But perhaps there are other examples?)

patrick-kidger avatar Jun 14 '25 22:06 patrick-kidger

Actually, there is one further problem here (and one that is a motivating reason for us not to do this): if a user has what they believe to be a full-rank operator, but mistakenly provides a rank-deficient operator, then SVD will silently compute a pseudoinverse (least-squares/least-norm) solve, rather than erroring out as would be desirable.

I see. So you're concerned about the situation where the user is using the SVD solver, intending that their operator be assumed full rank, but SVD ignores this and always does rcond filtering of singular values. If this is concerning, it seems like a problem already present in the existing implementation, and one with a simple fix: do the same thing that you do with the Diagonal solver and give SVD a well_posed argument. If true, SVD assumes all singular values are nonzero (rcond = 0, basically). Of course, if operators know whether they are full rank, both this argument and the one in Diagonal (and in a future pivoted QR) are unneeded, as they would just respect the operator tag.
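For concreteness, a sketch of that suggestion (Diagonal's well_posed flag is the existing behaviour referred to above; giving SVD the same flag is the hypothetical part):

import lineax as lx

# Existing: Diagonal already takes a well_posed flag.
diag_solver = lx.Diagonal(well_posed=False)

# Suggested (hypothetical, not current lineax API): SVD gains the same flag,
# with well_posed=True meaning "assume all singular values are nonzero",
# i.e. skip the rcond filtering.
# svd_solver = lx.SVD(well_posed=True)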

Or is your concern that, for the question of operator rank, you want the user to be forced to be explicit, and any default behavior is potentially confusing? If that is the case, you could just as well require all operators to be decorated with exactly one of a full_rank or possibly_not_full_rank tag (or some other equivalent design). My suggestion is only about which object the assumption should be encoded into, not what the default value should be. On the choice (default to assuming full rank, default to no assumption, or require an explicit choice), I don't have a strong opinion; there are reasonable arguments for all three options.

On this part, I think I like these changes. This was always a subtle point I have to think a lot about.

I think my main comment is that in general I think this could still be a function of any structure/sparsity in the operator? I'd need to noodle on this to be sure I'm getting it right, but I could believe that certain solvers might exhibit or not exhibit this behaviour depending on the operator. (The trivial example is of course a user-defined assume_full_rank tag that the solver checks and respects! But perhaps there are other examples?)

Structure and sparsity in the operator could potentially mean it is structurally rank deficient, and in that situation you would want the operator to know this. But still, the only question the JVP cares about is whether the rows of the operator are independent and whether the columns of the operator are independent. Mathematically, for any operator, the answer is given entirely by the following table:

| full rank? | operator size | cols independent | rows independent |
|------------|---------------|------------------|------------------|
| yes        | tall          | yes              | no               |
| yes        | square        | yes              | yes              |
| yes        | wide          | no               | yes              |
| no         | any           | no               | no               |

The JVP knows the size of the operator, so the only information it is missing is whether the operator is full rank.

adconner avatar Jun 14 '25 23:06 adconner

so the only information it is missing is whether the operator is full rank

So let's consider a solver that consumes a 2n x n matrix and a 2n-length vector and returns an n-length solution, and for which assume_full_rank = False; that is to say, it will always return a pseudoinverse solution.

I've not yet told you anything about the structure of the matrix or the behaviour of the solver.

According to the description you've given here, we would always compute a more-expensive JVP, corresponding to the allow-dependent-rows case.

However! It just so happens that I have a case in mind for which we may only need the inexpensive JVP. (The part corresponding to a true linsolve solution, not the more general pseudoinverse solution).

Namely, we introduce all of the following:

# A sketch, assuming lineax's abstract base classes.
from lineax import AbstractLinearOperator, AbstractLinearSolver

# Consumes two n x n operators and stacks them to create a 2n x n operator.
class TwoOperatorsInATrenchcoast(AbstractLinearOperator):
    first: AbstractLinearOperator
    second: AbstractLinearOperator

    ...

# Matrix of all zeros.
class ZeroLinearOperator(AbstractLinearOperator):
    ...

def is_full_and_zero(operator: AbstractLinearOperator) -> bool:
    return (
        isinstance(operator, TwoOperatorsInATrenchcoast)
        and isinstance(operator.second, ZeroLinearOperator)
    )

class WeirdSolver(AbstractLinearSolver):
    ...

    def allow_dependent_rows(self, operator):
        # Static: this depends only on the pytree structure of the operator,
        # not on any runtime values.
        return not is_full_and_zero(operator)

    def compute(self, ...):
        # is_lower_half_zero is elided for brevity; like the operator check,
        # it can be statically evaluated given suitable pytree structure.
        if is_full_and_zero(operator) and is_lower_half_zero(vector):
            # Perform an LU solve on the upper half, then pad with zeros.
            ...
        else:
            # Perform an SVD solve.
            ...
In this case we have that WeirdSolver is indeed satisfying its contract of only returning a pseudoinverse solution. In the fast-path branch, this coincides with performing a full-rank solve on a submatrix. When on that branch we can skip the extra JVP computation (the tangents in the zero regions are also structurally zero), and this is reflected in allow_dependent_rows.

Thus we have obtained a case in which the solver needs to consume the operator to determine how it handles dependent rows.


The above is obviously highly contrived. But I think it indicates that completely factoring out the operator height-vs-width directly into the JVP rule would indeed lose expressivity.

(Now, one could certainly make the argument that this is a worthy trade-off in the name of simplicity. And if we had done this the first time around I'd agree, but as it is I'm inclined to keep things as they are for the sake of backwards compatibility.)

WDYT?

patrick-kidger avatar Jun 27 '25 22:06 patrick-kidger

I now understand your concern. At the very least, the current allow_dependent_{rows,columns} are misnamed. It's really more like allow_nonconstant_{row,column}_space: the row span (or column span) of A(x) is independent of x if and only if the corresponding term of the JVP vanishes (on some neighborhood of x). The rows or columns being independent is just the special case of this notion in which the row/column span is constantly the whole vector space.

I could potentially get behind the idea that we might want to expose this possibility (better documented, and with more descriptive names) to the user. However, first consider that the extra generality we are exposing (a situation where the solver detects that the operator statically has constant row/column space) can already be obtained, in the cases where it is needed, without this feature. If the row span of A(x) is constant, then A(x) factors as A(x) = B C(x), where C(x) has independent rows and B is constant. A user in this case can compute A^+ b = C(x)^+ B^+ b and get the same JVP.
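A minimal sketch of that factored solve as user code (assuming B has independent columns and C(x) independent rows, so that the pseudoinverse reversal A^+ = C^+ B^+ holds):

import lineax as lx

def factored_pinv_solve(B, C, b):
    # A = B @ C with B full column rank (constant) and C full row rank, so
    # A^+ b = C^+ (B^+ b). Each factor is full rank, so QR suffices: a
    # least-squares solve against the tall factor B, then a least-norm
    # solve against the wide factor C.
    y = lx.linear_solve(lx.MatrixLinearOperator(B), b, solver=lx.QR()).value
    return lx.linear_solve(lx.MatrixLinearOperator(C), y, solver=lx.QR()).value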

Furthermore, the only nontrivial instance of this property which can ever occur structurally (looking only at the sparsity pattern of A) is precisely the one in your contrived example, where we discard rows/columns which are zero. And the solver only has access to structural information about A, so the only extra generality we are enabling is the detection of some zero rows or columns inside the solver. I think I might prefer the simpler system where the solver doesn't attempt to detect structurally zero rows/columns, and instead push this functionality to some solver wrapper which structurally factors A(x) = B C(x) D and computes A^+ b = D^T C(x)^+ B^T b, where B is a subset of columns of the identity matrix and D is a subset of rows of the identity matrix (if it is actually desired to implement this in the library, rather than leaving it to the user).

adconner avatar Jun 28 '25 09:06 adconner

Sorry for taking so long to get back to you; as this is a very technical + mostly internal change, it's been pushed down my priority list.

Anyway, I think we agree with each other then! As per my last message:

(Now, one could certainly make the argument that this is a worthy trade-off in the name of simplicity. And if we had done this the first time around I'd agree, but as it is I'm inclined to keep things as they are for the sake of backwards compatibility.)

Indeed, a simpler solution could have been arrived at, at the cost of an acceptable loss of generality. However, breaking backwards compatibility requires a very high bar, and I don't think that adjusting where we land on this trade-off meets it.

I am conscious of your other work in #159. I'd definitely still be very happy to get that in if you'd be willing to update it without this PR as a dependency?

patrick-kidger avatar Jul 18 '25 07:07 patrick-kidger

Okay, I am very belatedly getting back around to this PR now. (Personal life issues caught up for a while.)

I've just merged this into our dev branch. I really appreciate your patience, I think this is an improvement really worth having.

Also CC @johannahaffner who I think is coordinating the next release – this is a breaking change so we'll bump the version appropriately. :)

patrick-kidger avatar Dec 05 '25 00:12 patrick-kidger