Kronecker.jl

Set AD rules

Status: Open. elisno opened this issue 4 years ago (9 comments).

Resolves #92.

elisno, Apr 03 '21 00:04

Codecov Report

Merging #93 (acc9a90) into master (4967a5f) will decrease coverage by 1.25%. The diff coverage is 69.23%.


```diff
@@            Coverage Diff             @@
##           master      #93      +/-   ##
==========================================
- Coverage   90.96%   89.71%   -1.26%
==========================================
  Files          11       11
  Lines         620      632      +12
==========================================
+ Hits          564      567       +3
- Misses         56       65       +9
```
| Impacted Files | Coverage Δ |
|---|---|
| src/chainrules.jl | 69.23% <69.23%> (ø) |
| src/vectrick.jl | 90.55% <0.00%> (-3.94%) ⬇️ |


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 4967a5f...acc9a90.

codecov-io, Apr 03 '21 00:04

Your code does not seem to work for some examples and gives the wrong result for others.

For example:

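```julia
using Kronecker, LinearAlgebra, Zygote
A, B = rand(3, 3), rand(2, 2)
gradient((A, B) -> sum(A ⊗ B), A, B)       # lazy Kronecker product
gradient((A, B) -> sum(kron(A, B)), A, B)  # dense kron
```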

Most of our Kronecker functions fall back on regular functions that Zygote etc. should be able to handle fine. It works for `logdet` but not for `tr` and `sum` (which do work with the native `kron`, e.g. `gradient((A, B) -> sum(kron(A, B)), A, B)`). I'm not sure why, or how we can make ChainRulesCore fall back on this underlying code.
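For reference, the dense identities behind the `sum` and `tr` shortcuts, and the gradients they imply, can be sketched with LinearAlgebra alone. The helper names (`sum_kron`, `tr_kron`, `sum_kron_grad`, `tr_kron_grad`) are hypothetical, not Kronecker.jl functions; the point is that finite differences of the dense `kron` give a quick check against hand-written gradients:

```julia
using LinearAlgebra

# Dense identities behind the shortcuts (tr needs square A and B):
#   sum(kron(A, B)) == sum(A) * sum(B)
#   tr(kron(A, B))  == tr(A) * tr(B)
sum_kron(A, B) = sum(A) * sum(B)
tr_kron(A, B) = tr(A) * tr(B)

# Gradients implied by the product rule on the two scalar factors
sum_kron_grad(A, B) = (fill(sum(B), size(A)), fill(sum(A), size(B)))
tr_kron_grad(A, B) = (tr(B) * Matrix{Float64}(I, size(A)...),
                      tr(A) * Matrix{Float64}(I, size(B)...))
```

Because both functions are products of scalars, their gradients are just scaled all-ones and identity matrices, which is what a correct rule for the lazy product should return.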

Do you have a reference for your gradients?

MichielStock, Apr 06 '21 09:04

> Your code does not seem to work for some examples and gives the wrong result for others.

You're right. I started with the following loss function:

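```julia
using LinearAlgebra  # for tr

# Least-squares loss; y is passed in explicitly instead of being read from a global
function loss(A, B, X, y)
    Z = kron(A, B) * X - y
    L = 0.5 * tr(Z' * Z)
    return L
end
```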

where y has size (1, num_samples). I wrote kronecker_product_pullback in the rrule with this in mind, but forgot that each sample in y can have a higher dimension.

In test/testchainrules.jl, I compare `Zygote.gradient` against hand-written gradients for this trivial case, and do a second comparison with `kronecker`.

I decided to keep similar tests for higher dimensions, but marked them with `@test_broken` for now.
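A pullback that uses the incoming cotangent handles targets with any number of rows. As a dense sketch (LinearAlgebra only; `kron_pullback` and `loss_grad` are hypothetical helper names, not the PR's `kronecker_product_pullback`), the cotangent of `kron(A, B)` is reshaped into a four-index array and contracted with `B` for `∂A` and with `A` for `∂B`:

```julia
using LinearAlgebra

# Hypothetical helper: pullback of K = kron(A, B) for an arbitrary cotangent G = ∂L/∂K.
# A is m×n and B is p×q, so G is (m*p)×(n*q).
function kron_pullback(G, A, B)
    m, n = size(A)
    p, q = size(B)
    T = reshape(G, p, m, q, n)  # T[k, i, l, j] == G[(i-1)*p + k, (j-1)*q + l]
    ∂A = [sum(T[k, i, l, j] * B[k, l] for k in 1:p, l in 1:q) for i in 1:m, j in 1:n]
    ∂B = [sum(T[k, i, l, j] * A[i, j] for i in 1:m, j in 1:n) for k in 1:p, l in 1:q]
    return ∂A, ∂B
end

# Loss from the comment above, with the target y passed explicitly
function loss(A, B, X, y)
    Z = kron(A, B) * X - y
    return 0.5 * tr(Z' * Z)
end

# For L = 0.5 * tr(Z' * Z) the cotangent of kron(A, B) is Z * X'; pull it back through kron
function loss_grad(A, B, X, y)
    Z = kron(A, B) * X - y
    return kron_pullback(Z * X', A, B)
end
```

Since the pullback only depends on `G`, the same helper covers the `(1, num_samples)` case and targets with more rows; a finite-difference check on `loss` is a cheap way to compare it with hand-written gradients.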

elisno, Apr 06 '21 14:04

I've been experimenting with KroneckerSum as well.

I managed to get the correct values for the pullback:

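```julia
using ChainRulesCore, Kronecker, LinearAlgebra

function ChainRulesCore.frule((_, ΔA, ΔB), ::KroneckerSum, A::AbstractMatrix, B::AbstractMatrix)
    Ω = (A ⊕ B)
    ∂Ω = (ΔA ⊕ ΔB)  # ⊕ is linear, so the tangent is the Kronecker sum of the tangents
    return Ω, ∂Ω
end

function ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
    nA, nB = size(A, 1), size(B, 1)  # sizes taken from the inputs, not from globals
    function kronecker_sum_pullback(ΔΩ)
        # ΔΩ is not used: these values only match when ΔΩ == A ⊕ B
        ∂A = nB .* A + Diagonal(fill(tr(B), nA))
        ∂B = nA .* B + Diagonal(fill(tr(A), nB))
        return (NO_FIELDS, ∂A, ∂B)
    end
    return (A ⊕ B), kronecker_sum_pullback
end
```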

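```julia
using Zygote

nA = 3
nB = 2
Ar = rand(nA, nA)
Br = rand(nB, nB)
# Pullback through the lazy Kronecker sum
Y_lazy, back_lazy = Zygote._pullback(⊕, Ar, Br)
# Pullback through the equivalent dense expression A ⊗ I + I ⊗ B
Y, back = Zygote._pullback((x, y) -> kron(x, Diagonal(ones(nB))) + kron(Diagonal(ones(nA)), y), Ar, Br)

julia> back(Y)[2:end] .≈ back_lazy(Y_lazy)[2:end]
(true, true)
```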

Of course, this isn't useful for computing the gradient in more complicated expressions, since ΔΩ is not used in computing either ∂A or ∂B in the rrule.
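To use ΔΩ, the pullback can be written as partial traces of the cotangent: `∂A` traces ΔΩ over the B factor and `∂B` over the A factor. A minimal dense sketch (LinearAlgebra only; `kronsum_dense` and `kronsum_pullback` are hypothetical names, and the dense matrix stands in for the lazy `⊕`):

```julia
using LinearAlgebra

# Dense Kronecker sum A ⊕ B = A ⊗ I(nB) + I(nA) ⊗ B
function kronsum_dense(A, B)
    nA, nB = size(A, 1), size(B, 1)
    return kron(A, Matrix{Float64}(I, nB, nB)) + kron(Matrix{Float64}(I, nA, nA), B)
end

# Pullback of A ⊕ B for an arbitrary cotangent ΔΩ of size (nA*nB)×(nA*nB):
# ∂A is the partial trace of ΔΩ over the B factor, ∂B the partial trace over the A factor.
function kronsum_pullback(ΔΩ, nA, nB)
    T = reshape(ΔΩ, nB, nA, nB, nA)  # T[k, i, l, j] == ΔΩ[(i-1)*nB + k, (j-1)*nB + l]
    ∂A = [sum(T[k, i, k, j] for k in 1:nB) for i in 1:nA, j in 1:nA]
    ∂B = [sum(T[k, i, l, i] for i in 1:nA) for k in 1:nB, l in 1:nB]
    return ∂A, ∂B
end
```

With ΔΩ = A ⊕ B, as in the Zygote comparison above, this reduces to the closed forms ∂A = nB·A + tr(B)·I and ∂B = nA·B + tr(A)·I.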

elisno, Apr 06 '21 14:04

Note that

```julia
ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
```

overwrites

```julia
ChainRulesCore.rrule(::typeof(KroneckerProduct), A::AbstractMatrix, B::AbstractMatrix)
```

Should I use something else instead of ::typeof(KroneckerProduct)/::typeof(KroneckerSum)?
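Assuming `KroneckerProduct` and `KroneckerSum` are parametric structs, `typeof` of either one is `UnionAll`, so `::typeof(KroneckerSum)` and `::typeof(KroneckerProduct)` describe the same argument type and the second `rrule` method replaces the first. A minimal sketch with hypothetical stand-in types (`ProdLike`, `SumLike` and `rule` are illustrative names only) shows the collision and one way around it, dispatching on `::Type{<:T}`:

```julia
# Hypothetical stand-ins for two parametric types (not the actual Kronecker.jl structs)
struct ProdLike{T}
    a::T
    b::T
end
struct SumLike{T}
    a::T
    b::T
end

# typeof of a parametric type is UnionAll for both, so these two definitions
# would share one signature and the second would overwrite the first:
#   rule(::typeof(ProdLike), a, b)
#   rule(::typeof(SumLike), a, b)

# Dispatching on Type{<:...} keeps the two methods distinct
rule(::Type{<:ProdLike}, a, b) = :product
rule(::Type{<:SumLike}, a, b) = :sum
```

By analogy, the rules could presumably be declared as `ChainRulesCore.rrule(::Type{<:KroneckerSum}, A::AbstractMatrix, B::AbstractMatrix)` (and likewise for `KroneckerProduct`), or attached to the constructing function via `::typeof(⊕)`; this is an untested sketch.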

elisno, Apr 06 '21 14:04

I'm still stuck on this: why does computing gradients work for `logdet` but not for `tr` or `sum`? Shouldn't it just fall back to the simple shortcuts, for which adjoints already exist?

MichielStock, Apr 12 '21 13:04

Technically, it only makes sense to define the adjoints for those functions where Kronecker provides shortcuts, based on this rule: https://en.wikipedia.org/wiki/Matrix_calculus#Identities_in_differential_form
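As an example of such a shortcut, `logdet(A ⊗ B) = n_B logdet(A) + n_A logdet(B)` for square `A` (n_A×n_A) and `B` (n_B×n_B), and the identity d logdet(A) = tr(A⁻¹ dA) then gives ∂/∂A = n_B A⁻ᵀ and ∂/∂B = n_A B⁻ᵀ. A dense sketch, with hypothetical helper names and assuming both determinants are positive:

```julia
using LinearAlgebra

# Shortcut: logdet(A ⊗ B) = nB * logdet(A) + nA * logdet(B)
logdet_kron(A, B) = size(B, 1) * logdet(A) + size(A, 1) * logdet(B)

# Gradients read off from d logdet(A) = tr(A⁻¹ dA):
#   ∂/∂A = nB * inv(A)',  ∂/∂B = nA * inv(B)'
logdet_kron_grad(A, B) = (size(B, 1) * inv(A)', size(A, 1) * inv(B)')
```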

MichielStock, Apr 12 '21 13:04

> Still stuck on this, why does computing gradients work for logdet

Can you provide an MWE for `logdet`?

elisno, Apr 12 '21 13:04

> Technically, it only makes sense to define the adjoints for those functions where Kronecker provides shortcuts, based on this rule: https://en.wikipedia.org/wiki/Matrix_calculus#Identities_in_differential_form

Maybe I misunderstood, but doesn't this only provide the frules?
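A sketch of how one differential identity can yield both rules, for a scalar-valued f (the matrices G_A and G_B are notation introduced here, not taken from the linked page):

$$
df = \operatorname{tr}\!\left(G_A^\top\, dA\right) + \operatorname{tr}\!\left(G_B^\top\, dB\right)
\quad\Longrightarrow\quad
\dot f = \operatorname{tr}\!\left(G_A^\top \dot A\right) + \operatorname{tr}\!\left(G_B^\top \dot B\right),
\qquad
\bar A = \bar f\, G_A,\quad \bar B = \bar f\, G_B
$$

Here ḟ is the frule output for tangents Ȧ, Ḃ, and Ā, B̄ are the rrule cotangents for an incoming f̄. For example, d tr(A ⊗ B) = tr(B) tr(dA) + tr(A) tr(dB) gives G_A = tr(B) I and G_B = tr(A) I.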

elisno, Apr 12 '21 13:04