ChainRules.jl

Implement QR pullback

Open Kolaru opened this issue 5 years ago • 7 comments

Implement the pullback for the QR decomposition, following:

Walter and Lehmann, 2018, Algorithmic Differentiation of Linear Algebra Functions with Application in Optimum Experimental Design

Kolaru avatar Nov 15 '20 13:11 Kolaru

I wonder if the QR rule implied by Seeger et al in https://arxiv.org/pdf/1710.08717.pdf is more performant than the one in Walter and Lehmann? (they actually define an LQ rule, but the same approach produces a QR rule). The below reimplementation of this PR's qr_rev using this rule seems to outperform qr_rev in a simple benchmark while yielding the same result.

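function qr_rev2(QR_::ChainRules.QR_TYPE, Q̄, R̄)
    Q, R = QR_
    Q = Matrix(Q)  # materialize Q as a dense matrix
    # keep only the cotangent columns that match the columns of the dense Q
    Q̄ = Q̄ isa Zero ? Q̄ : @view Q̄[:, axes(Q, 2)]
    V = R̄*R' - Q'*Q̄
    # Hermitian(V) reads only the upper triangle of V and mirrors it
    Ā = (Q̄ + Q * Hermitian(V)) / R'
    return Ā
end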
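
For reference, here is my reading of what qr_rev2 computes, written as a formula for the real case. The notation sym_U is introduced here only for illustration (it is not taken from either paper) and stands for what `Hermitian(V)` does with its default upper-triangle setting:

```latex
V = \bar{R} R^{\top} - Q^{\top} \bar{Q},
\qquad
\bar{A} = \bigl(\bar{Q} + Q \,\operatorname{sym}_U(V)\bigr) R^{-\top},
\qquad
[\operatorname{sym}_U(V)]_{ij} =
\begin{cases}
V_{ij}, & i \le j, \\
V_{ji}, & i > j.
\end{cases}
```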

julia> A = randn(4, 4);

julia> F = qr(A);

julia> ΔF = Composite{typeof(F)}(Q = randn(eltype(F.Q), size(Matrix(F.Q))), R = randn(eltype(F.R), size(F.R)));

julia> @btime ChainRules.qr_rev($F, $(ΔF.Q), $(ΔF.R));
  16.374 μs (58 allocations: 7.11 KiB)

julia> @btime qr_rev2($F, $(ΔF.Q), $(ΔF.R));
  4.323 μs (12 allocations: 2.27 KiB)

julia> ChainRules.qr_rev(F, ΔF.Q, ΔF.R) ≈ qr_rev2(F, ΔF.Q, ΔF.R)
true

julia> A = randn(10, 4);

julia> F = qr(A);

julia> ΔF = Composite{typeof(F)}(Q = randn(eltype(F.Q), size(Matrix(F.Q))), R = randn(eltype(F.R), size(F.R)));

julia> @btime ChainRules.qr_rev($F, $(ΔF.Q), $(ΔF.R));
  56.299 μs (134 allocations: 20.86 KiB)

julia> @btime qr_rev2($F, $(ΔF.Q), $(ΔF.R));
  5.112 μs (12 allocations: 3.39 KiB)

julia> ChainRules.qr_rev(F, ΔF.Q, ΔF.R) ≈ qr_rev2(F, ΔF.Q, ΔF.R)
true
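
Agreement between the two implementations can also be cross-checked against finite differences. The following is a minimal sketch under some assumptions: real input with full column rank, a recent LinearAlgebra where `Matrix(F.Q)` returns the thin factor, and helper names (`qr_pullback_sketch`, `loss`, `fd_gradient`) that are hypothetical and not part of ChainRules.

```julia
using LinearAlgebra

# Hypothetical helper (not part of ChainRules): pullback of A -> (Q, R) for real A
# with full column rank, written from the rule used in qr_rev2.
function qr_pullback_sketch(A, Q̄, R̄)
    F = qr(A)
    Q = Matrix(F.Q)                        # thin Q, size(A, 1) × size(A, 2)
    R = F.R
    V = R̄ * R' - Q' * Q̄
    return (Q̄ + Q * Symmetric(V)) / R'    # Symmetric reads the upper triangle of V
end

# Scalar function of A whose gradient the pullback should reproduce.
function loss(A, Q̄, R̄)
    F = qr(A)
    return dot(Q̄, Matrix(F.Q)) + dot(R̄, F.R)
end

# Central finite differences, perturbing one entry of A at a time.
function fd_gradient(f, A; h = 1e-6)
    G = similar(A)
    for i in eachindex(A)
        Ap = copy(A); Ap[i] += h
        Am = copy(A); Am[i] -= h
        G[i] = (f(Ap) - f(Am)) / (2h)
    end
    return G
end

A = [2.0 1.0 0.0; 1.0 3.0 1.0; 0.0 1.0 4.0; 1.0 0.0 1.0]   # 4×3, full column rank
Q̄ = reshape(collect(1.0:12.0), 4, 3) ./ 10
R̄ = triu(reshape(collect(1.0:9.0), 3, 3) ./ 10)

Ā = qr_pullback_sketch(A, Q̄, R̄)
Ā_fd = fd_gradient(X -> loss(X, Q̄, R̄), A)
```

Comparing `Ā` with `Ā_fd` using `isapprox` checks correctness at one point only; it says nothing about numerical stability for ill-conditioned inputs.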

Or is the rule you have implemented expected to be more numerically stable?

sethaxen avatar Nov 21 '20 05:11 sethaxen

> Or is the rule you have implemented expected to be more numerically stable?

I have no idea. The two references we are using do not directly compare each other, and I do not know how to determine this myself.

Kolaru avatar Nov 22 '20 19:11 Kolaru

> > Or is the rule you have implemented expected to be more numerically stable?
>
> I have no idea. The two references we are using do not directly compare each other, and I do not know how to determine this myself.

The article I linked in https://github.com/JuliaDiff/ChainRules.jl/pull/306#discussion_r528169013 (https://arxiv.org/pdf/2009.10071.pdf), which covers wide and tall matrices as well, also uses the simpler rule from the Seeger et al paper. While I didn't implement their rules for wide and tall matrices, I ended up using a similar approach for the LU decomposition of wide and tall matrices in #354. For these reasons, I'm thinking the Seeger approach is preferable.

sethaxen avatar Feb 04 '21 21:02 sethaxen

I finally had time to come back to this, and it has been kind of a nightmare, because QR decompositions are represented in a weird way that does not play nicely with the tests and the comparison with FiniteDifferences.

After quite a lot of experimentation, I gave up on trying to make everything work with the default type returned by qr. Instead, I define a new custom type, ExplicitQR, that stores the Q and R matrices explicitly. I use this struct for the tests, and compare the end result with the one from the qr method to ensure the latter is correct as well.

I hope this is sufficient. Otherwise, I must admit I am out of ideas about what should be done to test qr directly.

I am aware of #469, which takes a somewhat different approach. I am not currently sure which is better.

Also, I implemented the algorithm suggested by @sethaxen.

So provided my way of testing is okay, this should be ready.

Kolaru avatar Jul 22 '21 02:07 Kolaru

Thanks, @Kolaru! I'll review #469 first, then this, then I'll recommend how to proceed.

sethaxen avatar Jul 22 '21 20:07 sethaxen

What I like about this approach is that it completely sidesteps much of the complexity of the objects returned by the qr methods, as described in https://github.com/JuliaDiff/ChainRules.jl/pull/469#issuecomment-939147971.

What I don't like is that the object being returned by the rrule is completely different from the one the user requested, which will cause some code that worked before to suddenly fail. For example, this code is totally valid with the qr return values in the std library:

using LinearAlgebra

A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)
y = Q*w + Q*v
@assert size(y) == (10,)

This works because AbstractQ objects are treated as thin or full depending on what they're multiplied by, but if Q is dense, then this no longer works.
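
To make the difference concrete, here is a small sketch that contrasts the lazy Q with a dense copy. It assumes a recent LinearAlgebra where `Matrix(Q)` returns the thin 10×5 factor; the variable names are only illustrative.

```julia
using LinearAlgebra

A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)

y_thin = Q * v    # length-5 input: Q acts like the thin 10×5 factor
y_full = Q * w    # length-10 input: Q acts like the full 10×10 factor

Qd = Matrix(Q)    # dense copy, assumed to be the thin 10×5 factor

# The thin product still agrees with the lazy Q.
dense_thin_ok = Qd * v ≈ y_thin

# The full product is no longer defined for the dense 10×5 matrix.
dense_full_fails = try
    Qd * w
    false
catch err
    err isa DimensionMismatch
end
```

Code that relies on both kinds of product would therefore break if an rrule handed back a dense Q in place of the lazy one.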

sethaxen avatar Oct 08 '21 22:10 sethaxen

I tried the following

using ChainRules: rrule, ExplicitQR
using LinearAlgebra

A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)

F, F_pullback = rrule(qr, A)
Q, Q_pullback = rrule(getproperty, F, :Q)
y1, y1_pullback = rrule(*, Q, v)
y2, y2_pullback = rrule(*, Q, w)

ȳ1 = rand(10)
_, Q̄1 = y1_pullback(ȳ1)
_, F̄1 = Q_pullback(Q̄1)
_, Ā1 = F_pullback(F̄1) 

ȳ2 = rand(10)
_, Q̄2 = y2_pullback(ȳ2)
_, F̄2 = Q_pullback(Q̄2)
_, Ā2 = F_pullback(F̄2)

and everything seems to be fine (i.e. nothing errors; I haven't tested correctness). Is this the correct way to test your point? Whether this PR or #469 is used, adding proper tests for the quirks of Q seems unavoidable.

Note that the ExplicitQR object I define is only used in the tests (and maybe its definition should be moved to the tests for clarity).

As far as I understand, ChainRules should be able to handle the QR objects properly. The problem I had was creating the tangent objects for the tests. The main issue is that we are taking derivatives with respect to fields that are not stored in the object directly[*]. After some tries, I just gave up and sidestepped the whole testing issue, hoping this is still enough to ensure correctness.

[*] The most common error I had was:

ArgumentError: Tangent fields do not match primal fields.
Tangent fields: (:Q,). Primal (LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}}) fields: (:factors, :T)

Using a custom struct with explicit fields Q and R naturally solved it.
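
The mismatch behind that error can be seen with the standard library alone. In the sketch below, `DemoExplicitQR` is a hypothetical stand-in written for illustration; the actual ExplicitQR definition from the PR is not shown in this thread and may differ.

```julia
using LinearAlgebra

A = randn(10, 5)
F = qr(A)

stored = fieldnames(typeof(F))   # fields the struct actually stores
Q, R = F.Q, F.R                  # Q and R are computed on access via getproperty

# Hypothetical stand-in for ExplicitQR: Q and R are real fields here,
# so a tangent with fields Q and R would line up with the primal.
struct DemoExplicitQR{T}
    Q::Matrix{T}
    R::Matrix{T}
end

E = DemoExplicitQR(Matrix(F.Q), Matrix(F.R))
explicit_fields = fieldnames(typeof(E))
```

Here `stored` contains `:factors` and `:T` but neither `:Q` nor `:R`, which is consistent with the error message above, while `explicit_fields` is exactly `(:Q, :R)`.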

Kolaru avatar Oct 09 '21 11:10 Kolaru