ChainRulesTestUtils.jl
ChainRulesTestUtils.jl copied to clipboard
test_rrule fails to handle QRCompactWY return types
The QR factorization returns a `QRCompactWY`, which describes two matrices. When testing a custom pullback for the QR factorization, `test_rrule` calls `length` on that type, for which `length` is not defined.
MWE:
using LinearAlgebra
using Zygote
using ChainRules
using ChainRulesCore
using Random
using Statistics
using FiniteDifferences
using ChainRulesTestUtils
# Override ChainRulesCore's `debug_mode` so that it returns `true`.
# NOTE(review): presumably this enables ChainRulesCore's extra consistency checks — confirm.
ChainRulesCore.debug_mode() = true
# Seed the default random number generator so the random input used below is reproducible.
Random.seed!(1234);
# Shared kernel of the QR pullback for a square or tall-and-skinny factor (m ≥ n),
# using the rule derived by Seeger et al. (2019) https://arxiv.org/pdf/1710.08717.pdf
#
#     Ā = [Q̄ + Q copyltu(·)] R⁻ᵀ
#
# where copyltu(C) is the symmetric matrix generated from C by taking its lower
# triangle and copying it to its upper triangle: copyltu(C)ᵢⱼ = C_{max(i,j), min(i,j)}.
#
# The wide case re-uses this kernel on the leading square block, which is why it lives
# in its own function.
#
# Arguments: `Q̄`, `R̄` are the cotangents of `Q` and `R` (either may be an
# `AbstractZero`); `Q`, `R` are the factors themselves. Returns the cotangent `Ā`.
function _qr_pullback_square_deep(Q̄, R̄, Q, R)
    M = R̄ * R' - Q' * Q̄
    # Symmetrize from the upper triangle: S[i, j] = M[min(i, j), max(i, j)],
    # which is copyltu applied to the transpose of M.
    S = triu(M) + transpose(triu(M, 1))
    return (Q̄ + Q * S) / R'
end

# Reverse-mode rule for `qr(A)`.
#
# Returns the factorization together with a pullback. The pullback expects a `Tangent`
# whose `factors` field carries Q̄ (the cotangent of `Q`) and whose `T` field carries R̄
# (the cotangent of `R`); either may be an `AbstractZero`. It returns
# `(NoTangent(), Ā)`.
function ChainRules.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T}
    QR = qr(A)
    m, n = size(A)

    function qr_pullback(F̄::Tangent)
        Q̄ = F̄.factors
        R̄ = F̄.T
        Q = QR.Q
        R = QR.R

        if m ≥ n
            # Square or tall-and-skinny: only the columns of Q̄ indexed by axes(Q, 2)
            # take part in the rule.
            Q̄cols = Q̄ isa ChainRules.AbstractZero ? Q̄ : view(Q̄, :, axes(Q, 2))
            Ā = _qr_pullback_square_deep(Q̄cols, R̄, Q, R)
        else
            # Wide case (m < n): rule derived by Liao et al. (2019)
            # https://arxiv.org/pdf/1903.09650.pdf
            #
            #     Ā = ([Q̄ + Y V̄ᵀ + Q copyltu(·)] U⁻ᵀ, Q V̄)
            #
            # where A = [X | Y] is the column-wise concatenation of X (m × m) and
            # Y (m × (n - m)), and R = [U | V] is partitioned the same way.
            #
            # See also the discussion in https://github.com/JuliaDiff/ChainRules.jl/pull/306
            # and https://github.com/pytorch/pytorch/blob/b162d95e461a5ea22f6840bf492a5dbb2ebbd151/torch/csrc/autograd/FunctionsManual.cpp

            # Partition A = [X | Y]; only Y is needed explicitly.
            Y = A[1:m, (m + 1):end]
            # Partition R = [U | V]; V itself is not needed.
            U = R[1:m, 1:m]

            if R̄ isa ChainRules.AbstractZero
                # Allocate the zero blocks with the element type of the factors so the
                # returned cotangent has the same eltype as in the non-zero branch
                # (plain `zeros(dims)` would always be Float64).
                V̄ = zeros(eltype(R), size(Y))
                Q̄_prime = zeros(eltype(R), size(Q))
                Ū = R̄
            else
                # Partition R̄ = [Ū | V̄]
                Ū = R̄[1:m, 1:m]
                V̄ = R̄[1:m, (m + 1):end]
                Q̄_prime = Y * V̄'
            end
            Q̄_total = Q̄ isa ChainRules.AbstractZero ? Q̄_prime : Q̄_prime + Q̄

            X̄ = _qr_pullback_square_deep(Q̄_total, Ū, Q, U)
            Ȳ = Q * V̄
            # Reassemble Ā = [X̄ | Ȳ]
            Ā = [X̄ Ȳ]
        end
        return (NoTangent(), Ā)
    end

    return QR, qr_pullback
end
# Reverse-mode rule for `getproperty(F, d)` on a compact-WY QR factorization.
#
# The QR factorization is stored as `factors` and `T`, matrices in the compact WY
# format, see R. Schreiber and C. van Loan, Sci. Stat. Comput. 10, 53-57 (1989).
# Instead of backpropagating through that representation, the `Tangent` returned by the
# pullback re-uses the `factors` slot to carry Q̄ and the `T` slot to carry R̄.
#
# NOTE(review): for any `d` other than `:Q` or `:R` both slots are zero, so the incoming
# cotangent is dropped — confirm that this is intended.
function ChainRulesCore.rrule(::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol)
    function getproperty_qr_pullback(Ȳ)
        # Use a proper zero tangent (not `nothing`) for the slot that was not accessed,
        # so that downstream `isa AbstractZero` checks recognise it.
        ∂factors = d === :Q ? Ȳ : ZeroTangent()
        ∂T = d === :R ? Ȳ : ZeroTangent()
        ∂F = Tangent{LinearAlgebra.QRCompactWY}(; factors=∂factors, T=∂T)
        # One tangent per primal input: the function itself, `F`, and the symbol `d`.
        return (NoTangent(), ∂F, NoTangent())
    end
    return getproperty(F, d), getproperty_qr_pullback
end
# Random 4×4 (square) input, so the `qr` pullback above takes its `m ≥ n` branch.
V = randn((4,4))
# Test the custom `rrule` for `qr`; this is the call that raises the error shown below.
test_rrule(qr, V)
Fails with:
Got exception outside of a @test
MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
Closest candidates are:
length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195
This is because `length` is not defined for `LinearAlgebra.QRCompactWY`:
julia> typeof(qr(V))
LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}}
julia> length(qr(V))
ERROR: MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
Closest candidates are:
length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195
...
Stacktrace:
[1] top-level scope
@ REPL[9]:1
I think that's because the Jacobian is an ill-defined notion when your function is not vector-to-vector. Many backends define it by flattening the inputs and outputs; we should probably do that in AD too.
Hi, I am having a similar issue with the `vcat` function; it may also be related to a few other issues I've seen raised.
I think I'm following the docs here, but:
import AbstractDifferentiation as AD, Zygote
# AbstractDifferentiation backend that delegates to Zygote.
backend = AD.ZygoteBackend()
# Two scalar inputs that get concatenated by `vcat`.
args = (1, 2)
# This works fine: `vcat` wrapped in an anonymous function, compared against
# Zygote's own Jacobian of `vcat`.
jac_from_ad = AD.jacobian(backend, (x, y) -> vcat(x,y), args...)
jac_from_zyg = Zygote.jacobian(vcat, args...)
@assert jac_from_ad == jac_from_zyg
# But this does not evaluate: passing `vcat` itself to `AD.jacobian`.
failed = AD.jacobian(backend, vcat, args...)