ChainRules.jl

Support for qr decomposition pullback

Open rkube opened this issue 4 years ago • 4 comments

Added an rrule for the qr decomposition. @sethaxen

rkube avatar Jul 12 '21 19:07 rkube

I've ported this pullback to CUDA.jl: https://gist.github.com/rkube/b17ef683409d76a3f01bcc590b85de6e Where would be a good place for that code?

rkube avatar Aug 25 '21 20:08 rkube

pokes @sethaxen (I can't really review this)

oxinabox avatar Sep 21 '21 17:09 oxinabox

> pokes @sethaxen (I can't really review this)

Yeah, sorry, I took a deep dive into the various QR parameterizations a few weeks back in prep for reviewing this, but I haven't had the chance to do the review yet. Sorry for the delay, @rkube!

sethaxen avatar Sep 21 '21 18:09 sethaxen

So this is a really tricky set of rules to define, perhaps trickier than any of the other rules we have in ChainRules currently. Here are just a few complications:

  • The signatures for qr are all changing with Julia v1.7 (below I use the 1.7 signatures)
  • qr can produce 4 different types in the standard library, summarized below:
# returns QRCompactWY via LAPACK.geqrt!
qr(A::StridedMatrix{<:BlasFloat}, pivot = NoPivot(); kwargs...)
qr!(A::StridedMatrix{<:BlasFloat}, ::NoPivot; kwargs...)

# returns QR via qrfactUnblocked!
qr(A::AbstractMatrix, pivot = NoPivot())
qr!(A::AbstractMatrix, ::NoPivot)

# returns QRPivoted via qrfactPivotedUnblocked!
qr(A::AbstractMatrix, ::ColumnNorm)
qr!(A::AbstractMatrix, ::ColumnNorm)

# returns SuiteSparse.SPQR.QRSparse
qr(A::SparseMatrixCSC, pivot = NoPivot())
  • None of the QR objects generate the Q matrix. Instead, they represent it in a compact form, where factors contains Householder reflectors in the strict lower trapezoid, and R in the upper trapezoid. Computing rules in terms of these compact elements is challenging, roughly as challenging as implementing the qr functions themselves.
  • Calling .Q on one of these factorizations produces an AbstractQ <: AbstractMatrix object that basically has all of the same fields. Because the AbstractQ objects are AbstractMatrix instances, by default they hit all of our AbstractMatrix rules and will therefore end up with AbstractMatrix cotangents unless we write custom rrules for every function one might call on a QR object.
  • The AbstractQ objects are weird. For an n×k matrix A, size(qr(A).Q) == (n, n). However, Q is also allowed to be multiplied by matrices with size (k, m). So consider code like the following:
A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)
y = Q*w + Q*v
@assert size(y) == (10,)

This is completely allowed, but note that the cotangent of Q will be ∂Q = ∂y * w' + ∂y * v'. This adds two matrices of size (10, 10) and (10, 5), respectively. The addition is performed by the AD engine and will error because the shapes differ, so it's then necessary to use ProjectTo to pad the (10, 5) matrix with zeros to make it (10, 10). That is very wasteful when dealing with very tall matrices, where one may never use the (10, 10) version of Q.
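To make the shape problem concrete, here is a minimal sketch that uses only the LinearAlgebra standard library. It rebuilds the two cotangent contributions by hand, shows that adding them naively fails, and then zero-pads the narrow one. The helper `pad_cotangent` is a hypothetical name introduced for illustration; it is not part of ChainRules and only mimics what a ProjectTo-style projection would have to do.

```julia
using LinearAlgebra

A = randn(10, 5)
Q = qr(A).Q
v = randn(5)
w = randn(10)
y = Q * w + Q * v                 # both products are accepted

# Q reports a square size, but materializing it gives the thin factor.
size_Q = size(Q)                  # (10, 10)
size_thin = size(Matrix(Q))       # (10, 5)

# Multiplying by the short vector acts like multiplying by a zero-padded one.
same_product = Q * v ≈ Q * [v; zeros(5)]

# Cotangent contributions to Q for some output cotangent ∂y.
∂y = randn(10)
∂Q_w = ∂y * w'                    # 10×10
∂Q_v = ∂y * v'                    # 10×5

# Naive accumulation fails because the shapes differ.
naive_sum_errors = try
    ∂Q_w + ∂Q_v
    false
catch err
    err isa DimensionMismatch
end

# Hypothetical helper (an assumption, not a ChainRules API):
# append zero columns until the cotangent has n columns.
pad_cotangent(∂Q, n) = hcat(∂Q, zeros(eltype(∂Q), size(∂Q, 1), n - size(∂Q, 2)))

∂Q = ∂Q_w + pad_cotangent(∂Q_v, 10)
```

The padded sum allocates a full 10×10 array even though half of the `∂Q_v` contribution is structurally zero, which is the waste described above for tall matrices.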

I don't think we can just address a subset of these complications one at a time. Once we start adding rules, which will override AD systems' default behavior of differentiating through the qr! fallback (for operator-overloading ADs), we will need more rules to make sure all of our rules compose nicely. I need to think more about whether there's a way to handle this without a tremendous amount of complication.

sethaxen avatar Oct 08 '21 22:10 sethaxen