ChainRulesTestUtils.jl icon indicating copy to clipboard operation
ChainRulesTestUtils.jl copied to clipboard

test_rrule fails to handle QRCompactWY return types

Open rkube opened this issue 3 years ago • 1 comments

The QR factorization returns a `QRCompactWY` object, which stores two matrices (`factors` and `T`). When testing a custom pullback for the QR factorization, `test_rrule` calls `length` on that object, but `length` is not defined for this type.

MWE:

using LinearAlgebra
using Zygote
using ChainRules
using ChainRulesCore
using Random
using Statistics
using FiniteDifferences

using ChainRulesTestUtils

# Define `ChainRulesCore.debug_mode()` to return `true`.
# NOTE(review): presumably this switches ChainRulesCore to more verbose error reporting — confirm.
ChainRulesCore.debug_mode() = true
# Seed the global RNG so the random test matrix created later is reproducible.
Random.seed!(1234);

# Reverse-mode rule for `qr(A)`.
#
# Returns the factorization `QR = qr(A)` together with a pullback. The pullback takes a
# `Tangent` whose `factors` field carries Q̄ and whose `T` field carries R̄ and returns
# `(NoTangent(), Ā)`.
function ChainRules.rrule(::typeof(qr), A::AbstractMatrix{T}) where {T}
    QR = qr(A)
    m, n = size(A)
    function qr_pullback(Ȳ::Tangent)
        # For square (m = n) or tall and skinny (m ≥ n), use the rule derived by
        # Seeger et al. (2019) https://arxiv.org/pdf/1710.08717.pdf
        #
        # Ā = [Q̄ + Q copyltu(M)] R⁻ᵀ
        #
        # where copyltu(C) is the symmetric matrix generated from C by taking the lower
        # triangle of the input and copying it to its upper triangle:
        # copyltu(C)ᵢⱼ = C_{max(i,j), min(i,j)}
        #
        # This code is re-used in the wide case, so it lives in a separate helper.
        # The helper returns its result instead of assigning to `Ā`: an assignment inside
        # the nested function would write to the enclosing pullback's `Ā` and force that
        # variable to be captured by the closure.
        function qr_pullback_square_deep(Q̄, R̄, Q, R)
            M = R̄ * R' - Q' * Q̄
            # Symmetrise M from its upper triangle: result[i, j] = M[min(i,j), max(i,j)].
            # NOTE(review): M as built here appears to be the transpose of the M used in
            # the paper, which would make this equal to copyltu of the paper's M — confirm.
            M = triu(M) + transpose(triu(M, 1))
            return (Q̄ + Q * M) / R'
        end

        # For the wide (m < n) case, we implement the rule derived by
        # Liao et al. (2019) https://arxiv.org/pdf/1903.09650.pdf
        #
        # Ā = ([Q̄ + Y V̄ᵀ + Q copyltu(M)] U⁻ᵀ, Q V̄)
        #
        # where A = (X, Y) is the column-wise concatenation of X (m×m) and Y (m×(n-m)),
        # and R = (U, V) is partitioned the same way. X and U are assumed to be full-rank
        # square matrices.
        #
        # See also the discussion in https://github.com/JuliaDiff/ChainRules.jl/pull/306
        # And https://github.com/pytorch/pytorch/blob/b162d95e461a5ea22f6840bf492a5dbb2ebbd151/torch/csrc/autograd/FunctionsManual.cpp
        Q̄ = Ȳ.factors
        R̄ = Ȳ.T
        Q = QR.Q
        R = QR.R
        if m ≥ n
            Q̄ = Q̄ isa ChainRules.AbstractZero ? Q̄ : @view Q̄[:, axes(Q, 2)]
            Ā = qr_pullback_square_deep(Q̄, R̄, Q, R)
        else
            # partition A = [X | Y]; only Y is needed explicitly
            Y = A[1:m, m + 1:end]

            # partition R = [U | V], and we don't need V
            U = R[1:m, 1:m]
            if R̄ isa ChainRules.AbstractZero
                # Allocate the zero blocks in the element type of the factorization so
                # that e.g. Float32 inputs are not silently promoted to Float64.
                S = eltype(R)
                V̄ = zeros(S, size(Y))
                Q̄_prime = zeros(S, size(Q))
                Ū = R̄
            else
                # partition R̄ = [Ū | V̄]
                Ū = R̄[1:m, 1:m]
                V̄ = R̄[1:m, m + 1:end]
                Q̄_prime = Y * V̄'
            end

            Q̄_prime = Q̄ isa ChainRules.AbstractZero ? Q̄_prime : Q̄_prime + Q̄

            X̄ = qr_pullback_square_deep(Q̄_prime, Ū, Q, U)
            # Cotangent of the Y block; kept under its own name so the pullback
            # argument Ȳ is not rebound to a value of a different type.
            Ȳ_block = Q * V̄
            # partition Ā = [X̄ | Ȳ]
            Ā = [X̄ Ȳ_block]
        end
        return (NoTangent(), Ā)
    end
    return QR, qr_pullback
end


# Reverse-mode rule for `getproperty(F, d)` on a `QRCompactWY` factorization.
#
# Returns the requested property together with a pullback. The pullback wraps the
# incoming cotangent Ȳ in a `Tangent{QRCompactWY}` and returns one tangent per primal
# position: the function `getproperty`, the factorization `F`, and the symbol `d`.
function ChainRulesCore.rrule(::typeof(getproperty), F::LinearAlgebra.QRCompactWY, d::Symbol)
    function getproperty_qr_pullback(Ȳ)
        # The QR factorization is calculated from `factors` and T, matrices stored in the
        # QRCompactWYQ format, see
        # R. Schreiber and C. van Loan, Sci. Stat. Comput. 10, 53-57 (1989).
        # Instead of backpropagating through the factors, we re-use `factors` to carry Q̄
        # and `T` to carry R̄ in the Tangent object.
        #
        # Fields that receive no cotangent are filled with `ZeroTangent()` rather than
        # `nothing`, so that downstream `isa AbstractZero` checks recognise them.
        # For any `d` other than :Q or :R both fields are zero.
        ∂factors = d === :Q ? Ȳ : ZeroTangent()
        ∂T = d === :R ? Ȳ : ZeroTangent()

        ∂F = Tangent{LinearAlgebra.QRCompactWY}(; factors=∂factors, T=∂T)
        # The trailing NoTangent() is the tangent for the non-differentiable symbol `d`.
        return (NoTangent(), ∂F, NoTangent())
    end

    return getproperty(F, d), getproperty_qr_pullback
end

# Random 4×4 matrix used as the primal input (square case, m == n).
V = randn((4,4))
# Run ChainRulesTestUtils' rrule check for `qr` on this input; this is the call that
# raises the MethodError reported below.
test_rrule(qr, V)

Fails with:

  Got exception outside of a @test
  MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
  Closest candidates are:
    length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
    length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
    length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195

This is because length is not defined for ::LinearAlgebra.QRCompactWY:

julia> typeof(qr(V))
LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}}
julia> length(qr(V))
ERROR: MethodError: no method matching length(::LinearAlgebra.QRCompactWY{Float32, Matrix{Float32}})
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at abstractdict.jl:58
  length(::Union{Adjoint{T, var"#s6"} where var"#s6"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, Diagonal{T, var"#s13"} where var"#s13"<:(StaticArrays.StaticVector{var"#s14", T} where var"#s14"), Hermitian{T, var"#s10"} where var"#s10"<:(StaticArrays.StaticMatrix{var"#s11", var"#s12", T} where {var"#s11", var"#s12"}), LowerTriangular{T, var"#s18"} where var"#s18"<:(StaticArrays.StaticMatrix{var"#s19", var"#s20", T} where {var"#s19", var"#s20"}), Symmetric{T, var"#s7"} where var"#s7"<:(StaticArrays.StaticMatrix{var"#s8", var"#s9", T} where {var"#s8", var"#s9"}), Transpose{T, var"#s1"} where var"#s1"<:Union{StaticArrays.StaticVector{var"#s1", T} where var"#s1", StaticArrays.StaticMatrix{var"#s4", var"#s5", T} where {var"#s4", var"#s5"}}, UnitLowerTriangular{T, var"#s24"} where var"#s24"<:(StaticArrays.StaticMatrix{var"#s25", var"#s26", T} where {var"#s25", var"#s26"}), UnitUpperTriangular{T, var"#s21"} where var"#s21"<:(StaticArrays.StaticMatrix{var"#s22", var"#s23", T} where {var"#s22", var"#s23"}), UpperTriangular{T, var"#s15"} where var"#s15"<:(StaticArrays.StaticMatrix{var"#s16", var"#s17", T} where {var"#s16", var"#s17"}), StaticArrays.StaticVector{var"#s26", T} where var"#s26", StaticArrays.StaticMatrix{var"#s5", var"#s4", T} where {var"#s5", var"#s4"}, StaticArrays.StaticArray{var"#s26", T, N} where {var"#s26"<:Tuple, N}} where T) at /home/rkube/.julia/packages/StaticArrays/uH2MB/src/abstractarray.jl:1
  length(::Union{Adjoint{T, S}, Transpose{T, S}} where {T, S}) at /buildworker/worker/package_linuxppc64le/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/adjtrans.jl:195
  ...
Stacktrace:
 [1] top-level scope
   @ REPL[9]:1

rkube avatar Jun 25 '21 12:06 rkube

I think that's due to the Jacobian being an ill-defined notion when your function is not vector-to-vector. Many backends define it by flattening the inputs and outputs; we should probably do that in AD too.

gdalle avatar Aug 05 '23 15:08 gdalle

Hi, I am having a similar issue with the `vcat` function; it may also be related to a few other issues I've seen raised.

I think I'm following the docs here but:

import AbstractDifferentiation as AD, Zygote

# Use Zygote as the AbstractDifferentiation backend.
backend = AD.ZygoteBackend()
# Two scalar inputs that are splatted into each Jacobian call below.
args = (1, 2)

# This works fine: `vcat` is wrapped in an anonymous two-argument function, and the
# result agrees with calling Zygote's own `jacobian` on `vcat` directly.
jac_from_ad = AD.jacobian(backend, (x, y) -> vcat(x,y), args...)
jac_from_zyg = Zygote.jacobian(vcat, args...)
@assert jac_from_ad == jac_from_zyg

# But this does not evaluate: the only difference from the working call above is that
# `vcat` is passed directly instead of through the anonymous wrapper.
failed = AD.jacobian(backend, vcat, args...)

olivierlabayle avatar Nov 29 '23 10:11 olivierlabayle