ChainRules.jl
Support for qr decomposition pullback
Added an rrule for the qr decomposition. @sethaxen
I've ported this pullback to CUDA.jl: https://gist.github.com/rkube/b17ef683409d76a3f01bcc590b85de6e Where would be a good place for that code?
pokes @sethaxen (I can't really review this)
Yeah, sorry, I took the deep dive studying the various QR parameterizations a few weeks back in prep for reviewing this but haven't had the chance to yet. Sorry for the delay, @rkube!
So this is a really tricky set of rules to define, perhaps trickier than any of the other rules we have in ChainRules currently. Here are just a few complications:
- The signatures for `qr` are all changing with Julia v1.7 (below I use the 1.7 signatures).
- `qr` can produce 4 different types in the standard library, summarized below:
# returns QRCompactWY via LAPACK.geqrt!
qr(A::StridedMatrix{<:BlasFloat}, pivot = NoPivot(); kwargs...)
qr!(A::StridedMatrix{<:BlasFloat}, ::NoPivot; kwargs...)
# returns QR via qrfactUnblocked!
qr(A::AbstractMatrix, pivot = NoPivot())
qr!(A::AbstractMatrix, ::NoPivot)
# returns QRPivoted via qrfactPivotedUnblocked!
qr(A::AbstractMatrix, ::ColumnNorm)
qr!(A::AbstractMatrix, ::ColumnNorm)
# returns SuiteSparse.SPQR.QRSparse
qr(A::SparseMatrixCSC, pivot = NoPivot())
- None of the `QR` objects generate the `Q` matrix. Instead, they represent it in a compact form, where `factors` contains Householder reflectors in the strict lower trapezoid, and `R` in the upper trapezoid. Computing rules in terms of these compact elements is challenging, roughly as challenging as implementing the `qr` functions themselves.
- Calling `.Q` on one of these factorizations produces an `AbstractQ <: AbstractMatrix` object that basically has all of the same fields. The `AbstractQ` objects are `AbstractMatrix`es, which means they by default hit all of our `AbstractMatrix` rules and therefore will end up with `AbstractMatrix` cotangents unless we write custom `rrule`s for every function one might call on a `QR` object.
- The `AbstractQ` objects are weird. For an `nxk` matrix `A`, `size(qr(A).Q) == (n, n)`. However, `Q` is also allowed to be multiplied by matrices with size `(k, m)`. So consider code like the following:
using LinearAlgebra

A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)
y = Q*w + Q*v
@assert size(y) == (10,)
This is completely allowed, but note that the cotangent of Q will be ∂Q = ∂y * w' + ∂y * v'. This adds two matrices of size (10, 10) and (10, 5), respectively. The AD engine performs this addition, and it will error because the sizes differ. It is therefore necessary to use ProjectTo to pad the (10, 5) matrix with zeros to make it (10, 10), but this is very wasteful for very tall matrices, where one may never use the (10, 10) version.
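To make the shape mismatch concrete, here is a minimal sketch using plain arrays in place of what an AD engine would accumulate. The helper `pad_columns` is hypothetical (it is not part of ChainRules or `ProjectTo`); it only illustrates the zero-padding idea described above, and `∂y` is a random stand-in for the incoming cotangent.

```julia
using LinearAlgebra

n, k = 10, 5
w = randn(n)
v = randn(k)
∂y = randn(n)            # stand-in for the cotangent of y

∂Q_w = ∂y * w'           # contribution from Q*w, size (n, n)
∂Q_v = ∂y * v'           # contribution from Q*v, size (n, k)

# Hypothetical helper (not a ChainRules API): zero-pad an n×k cotangent
# on the right so that it has `ncols` columns.
function pad_columns(∂Q::AbstractMatrix, ncols::Integer)
    padded = zeros(eltype(∂Q), size(∂Q, 1), ncols)
    padded[:, 1:size(∂Q, 2)] .= ∂Q
    return padded
end

# ∂Q_w + ∂Q_v would throw a DimensionMismatch; padding makes the sum valid.
∂Q = ∂Q_w + pad_columns(∂Q_v, n)
```

Note that the padded array allocates `n * (n - k)` extra zeros, which is the waste mentioned above when `n` is much larger than `k`.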
I don't think we can address a subset of these complications one at a time. Once we start adding rules, which will override AD systems' default behavior of differentiating through the qr! fallback (for operator-overloading ADs), we will need more rules to make sure all of our rules compose nicely. I need to think more about whether this can be handled without a tremendous amount of complication.
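For reference, the return types and compact storage listed earlier in this comment can be checked with a short sketch like the one below. It assumes Julia 1.7 or later (for `ColumnNorm`) and uses a `BigFloat` matrix only as a convenient way to reach the generic non-BLAS method; the sparse case is left out to avoid an extra dependency.

```julia
using LinearAlgebra

A = randn(10, 5)

F_blas    = qr(A)                 # BLAS floats, no pivoting
F_generic = qr(big.(A))           # non-BLAS element type, no pivoting
F_pivoted = qr(A, ColumnNorm())   # column-pivoted (Julia >= 1.7 signature)

@show typeof(F_blas)
@show typeof(F_generic)
@show typeof(F_pivoted)

# The factorization stores compact `factors`, not an explicit Q matrix.
@show size(F_blas.factors)        # same size as A
@show size(F_blas.Q)              # reports the square size

# Materialize Q only when needed; the first 5 columns reproduce A with R.
Qthin = Matrix(F_blas.Q)[:, 1:5]
@show Qthin * F_blas.R ≈ A
```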