
Rules for det are not general and likely unstable

Open cortner opened this issue 2 years ago • 8 comments

A = [1.0 1.0; 1.0 1.0] 
det(A) * inv(A) 
# ERROR: SingularException(2)

To be fair, the LU approach is remarkably resilient...

A = [1.0 1.0; 1.0 1.0-1e-16]
det(A) * inv(A)
# 2×2 Matrix{Float64}:
#   1.0  -1.0
#  -1.0   1.0

But I am nervous that this is too simple a special case to draw any conclusions.

Either way - it seems to me that an implementation that always works and is guaranteed to be numerically stable is preferable. If the maintainers agree, I would be happy to make a PR that implements the frule and rrule via SVD instead of LU.

Questions:

  • should the det evaluation then also use SVD, or should it use whatever the standard implementation is?
  • should the SVD implementation be only for a specific class of matrices, or general? E.g. I'm unsure about sparse.
  • I noticed that there are rules for svd - I'm guessing this means one will be able to differentiate the frule or rrule based on SVD?

cortner avatar Jul 06 '21 17:07 cortner

The rules for the determinant y = det(A) are ultimately due to Giles, who implicitly assumed A is nonsingular. The forward-mode rule is ẏ = y * tr(A \ Ȧ). We can simplify this a bit by noting that det(A) * inv(A) = transpose(C), where C is the cofactor matrix, giving us the general forward-mode rule ẏ = tr(transpose(C) * Ȧ). This rule is stable for all A, since C consists only of products and sums.
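For concreteness, here is a minimal sketch of that Giles-style forward rule (the function name is hypothetical; this is not the actual ChainRules.jl code):

using LinearAlgebra

# Forward-mode rule for y = det(A) in the nonsingular case (Giles).
# Ȧ is the input tangent; returns (y, ẏ).
function det_pushforward(A, Ȧ)
    y = det(A)
    ẏ = y * tr(A \ Ȧ)  # breaks down when A is singular: A \ Ȧ throws SingularException
    return y, ẏ
end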

As an example, the cofactor matrix for a 2×2 matrix A = [a b; c d] would be C = [d -c; -b a], so the forward-mode rule is ẏ = tr(transpose(C) * Ȧ) = ȧ*d + a*ḋ - ḃ*c - b*ċ, which is exactly what we'd get by differentiating y = a * d - b * c.

Unlike inv and det, we don't have existing algorithms to compute C from some factorization, but it is encouraging to know that a stable rule for det is possible.

Using the SVD decomposition A = U*Diagonal(S)*V', the forward-mode rule is ẏ = sum(diag(U' * Ȧ * V) .* (y ./ S)). If exactly one singular value is zero, then this simplifies to ẏ = @views dot(U[:,end], Ȧ, V[:,end]) * prod(S[1:end-1]). The trouble is that SVD is both much slower and less accurate than LU for this purpose: singular matrices will in general have a very tiny singular value rather than an exactly zero one, so you end up with cancellation error and not a very accurate rule.
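A sketch of that SVD-based pushforward (again with a hypothetical name):

using LinearAlgebra

# SVD-based pushforward for y = det(A), using A = U*Diagonal(S)*V'.
# Avoids A \ Ȧ, but suffers cancellation when a singular value is merely tiny.
function det_pushforward_svd(A, Ȧ)
    F = svd(A)
    U, S, V = F.U, F.S, F.V
    y = det(A)
    ẏ = sum(diag(U' * Ȧ * V) .* (y ./ S))
    return y, ẏ
end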

But we can get the same thing using the LU decomposition A = L*U by just replacing any zeros on the diagonal of U with eps()s. The fact that this works indicates that there's a cancellation of tiny numbers happening in the LU rule, and if we can identify where it is, we can reparameterize to cancel it symbolically and lose no precision without sacrificing speed.
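A sketch of that eps() trick, for illustration only (the packed F.factors field stores L below the diagonal and U on and above it):

using LinearAlgebra

# Replace exact zero pivots of U with eps() before applying the Giles rule.
function det_pushforward_lu(A, Ȧ)
    F = lu(A; check = false)       # a "failed" factorization is still completed
    d = F.factors                  # diagonal of the packed factors is diag(U)
    for i in 1:min(size(d)...)
        iszero(d[i, i]) && (d[i, i] = eps(real(eltype(d))))
    end
    y = det(F)
    ẏ = y * tr(F \ Ȧ)              # the tiny pivots cancel between y and F \ Ȧ
    return y, ẏ
end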

I implemented the LU rules, so I've been looking into this for the LU case. As mentioned in https://github.com/JuliaDiff/ChainRules.jl/issues/456#issuecomment-869165443, I think the right thing to do is to have det rules only for Factorization inputs, not for matrices themselves. This is much preferable to overriding the primal's method of computing det.

sethaxen avatar Jul 08 '21 20:07 sethaxen

It is easy to make a stable Ddet using the SVD. Yes, it is slower, but it is not less accurate; where did you hear that? As a matter of fact, I would expect the SVD to be more accurate/stable than LU. Cf. e.g. here:

Computing the SVD is always numerically stable for any matrix, but is typically more expensive than other decompositions

cortner avatar Jul 08 '21 20:07 cortner

Interesting thought about the eps() on the diagonal. But what do you do if the matrix is singular and the LU decomposition won't even compute?

cortner avatar Jul 08 '21 20:07 cortner

The fact that this works indicates that there's a cancellation of tiny numbers happening in the LU rule

this is correct and can be checked with a quick round-off error analysis.

cortner avatar Jul 08 '21 20:07 cortner

It is easy to make a stable Ddet using the SVD. Yes, it is slower, but it is not less accurate; where did you hear that? As a matter of fact, I would expect the SVD to be more accurate/stable than LU.

For the purposes of computing the determinant, svd is less accurate than lu, because tiny numerical errors enter into the computation of the singular values, so exactly singular matrices end up with no zero-valued singular values. E.g.

julia> A = [1 2; 3 6];

julia> iszero(det(lu(A; check=false)))  # check=false silences the error for a "failed" factorization 
true

julia> iszero(prod(svd(A).S))
false

This can happen with lu as well but seems to be less frequent.

Interesting thought about the eps() on the diagonal. But what do you do if the matrix is singular and the LU decomposition won't even compute?

A "failed" LU factorization is still completed, it's just that it can't be used for solving without some special-casing. See http://www.netlib.org/lapack/explore-3.1.1-html/dgetrf.f.html

sethaxen avatar Jul 08 '21 20:07 sethaxen

I came across the paper "On the adjugate matrix" by G. W. Stewart (https://doi.org/10.1016/S0024-3795(98)10098-8), which does an error analysis showing that, when working with decompositions of the form A = X*D*Y for well-conditioned X and Y and diagonal D, one can accurately and stably compute the adjugate matrix (the transpose of the cofactor matrix) by replacing any zero elements on the diagonal of D with a small number and computing adj(A) = det(A) * inv(A). So e.g. this is completely safe to do for SVD. For the LU decomposition, the error analysis only applies to the completely pivoted form, which we don't have in LinearAlgebra, but he noted that in his tests, doing the same for the partially pivoted form appears to be both accurate and stable.
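A sketch of Stewart's recipe through the SVD (the choice of the replacement value δ is purely illustrative; picking it well is discussed further below):

using LinearAlgebra

# adj(A) = det(A) * inv(A), computed through a perturbed SVD A = U*Diagonal(S)*V'.
function adjugate_svd(A)
    F = svd(A)
    δ = eps(real(eltype(A))) * maximum(F.S)   # illustrative perturbation size
    S = map(s -> iszero(s) ? δ : s, F.S)
    d = det(F.U) * conj(det(F.V)) * prod(S)   # det(A) from the factors
    return d * (F.V * Diagonal(inv.(S)) * F.U')
end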

The way I suggest proceeding is:

  1. Update all factorization frules/rrules to ensure that if invertibility is assumed, then a tiny perturbation is used for singular matrices to avoid simply erroring.
  2. Add rules for det with factorization inputs, which would need to use the same perturbation (a rough sketch follows this list).
  3. Remove the general purpose det rule. This ensures that the user-chosen factorization for computing det is always used, and we just make sure the factorization and det rules compose correctly in the singular case.
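As an illustration of step 2, here is a rough frule sketch for LU inputs. It assumes the tangent of an LU is a Tangent carrying a factors field, as the existing LU rules use; this is a sketch, not the eventual implementation:

using ChainRulesCore, LinearAlgebra

# det(F) = ±prod(diag(F.factors)), so only the diagonal of the packed factors
# enters; assumes zero pivots were already perturbed away (step 1).
function ChainRulesCore.frule((_, ΔF), ::typeof(det), F::LU)
    y = det(F)
    Δfactors = unthunk(ΔF).factors
    ẏ = y * sum(diag(Δfactors) ./ diag(F.factors))
    return y, ẏ
end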

sethaxen avatar Jul 21 '21 13:07 sethaxen

My two cents from the PyTorch end of things.

We implement the algorithm mentioned in the cited paper in PyTorch. It works fine for the first derivative, but for the second derivative it sometimes fails (we compute the second derivative by applying autodiff to the first derivative). https://github.com/pytorch/pytorch/blob/6afb414c21543fb62e2f862268a707657c83cd72/aten/src/ATen/native/BatchLinearAlgebra.cpp#L3748-L3760

The tricky part is to come up with a reasonable ε for the perturbation. We currently use the machine epsilon of the given datatype as ε.
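In Julia terms, that choice would be something like:

ε = eps(real(eltype(A)))  # machine epsilon of the element type, e.g. 2.22e-16 for Float64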

The implementation in PyTorch is not particularly clean (we plan to give it a clean-up sometime in the future), but the checks show that it is correct and pretty stable for a range of singular matrices :)

For the second derivative we implement a "not-so-good" SVD-based algorithm when the matrix is singular. This works in practice, but the algorithm could be very much improved. Now, this is a story for another time :)

lezcano avatar Nov 17 '21 10:11 lezcano

If you change the rules for det, make sure to test them with matrices that have a complex determinant, e.g. a random unitary. The current rule gives wrong results, see #600, and most importantly, the "well-conditioned" matrix that is currently used as a complex-valued test case actually has a real-valued determinant, and thus doesn't catch that bug.

goerz avatar Mar 20 '22 04:03 goerz