Set AD rules
Resolves #92.
Codecov Report

Merging #93 (acc9a90) into master (4967a5f) will decrease coverage by 1.25%. The diff coverage is 69.23%.

```
@@            Coverage Diff             @@
##           master      #93      +/-   ##
==========================================
- Coverage   90.96%   89.71%   -1.26%
==========================================
  Files          11       11
  Lines         620      632      +12
==========================================
+ Hits          564      567       +3
- Misses         56       65       +9
```

| Impacted Files | Coverage Δ | |
|---|---|---|
| src/chainrules.jl | 69.23% <69.23%> (ø) | |
| src/vectrick.jl | 90.55% <0.00%> (-3.94%) | :arrow_down: |
Your code does not seem to work for some examples and gives the wrong result for others.
For example:

```julia
gradient((A, B) -> sum(A ⊗ B), A, B)
gradient((A, B) -> sum(kron(A, B)), A, B)
```
Most of our Kronecker functions fall back on regular functions that Zygote etc. should be able to handle fine. It works for logdet but not for tr and sum (which do work with the native kron, e.g. gradient((A, B) -> sum(kron(A, B)), A, B)). Not sure why, or how we can make ChainRulesCore fall back on this underlying code.
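A minimal illustration of the contrast, assuming Kronecker's shortcuts logdet(A ⊗ B) = nB·logdet(A) + nA·logdet(B) and tr(A ⊗ B) = tr(A)·tr(B):

```julia
using Kronecker, Zygote, LinearAlgebra

A = rand(3, 3) + 3I   # shifted so the determinant stays positive for logdet
B = rand(2, 2) + 2I

# Differentiating through the logdet shortcut works:
Zygote.gradient((A, B) -> logdet(A ⊗ B), A, B)

# tr and sum have shortcuts too, yet these fail or give wrong results
# here, while the dense kron fallback succeeds:
Zygote.gradient((A, B) -> tr(A ⊗ B), A, B)
Zygote.gradient((A, B) -> sum(A ⊗ B), A, B)
Zygote.gradient((A, B) -> sum(kron(A, B)), A, B)   # reference: works
```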
Do you have a reference for your gradients?
> Your code does not seem to work for some examples and gives the wrong result for others.
You're right, I started with the following loss function:
```julia
function loss(A, B, X)
    Z = kron(A, B) * X - y
    L = 0.5 * tr(Z' * Z)
    return L
end
```
where y has size (1, num_samples).
I wrote kronecker_product_pullback in the rrule with this in mind, but forgot that each sample in y can have a higher dimension.
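For this loss, ∂L/∂K = Z·Xᵀ with K = kron(A, B) and Z = K·X − y, so the remaining step is pulling ΔΩ = Z·Xᵀ back through kron itself. A generic sketch of that step for real matrices (hypothetical code, not the PR's kronecker_product_pullback):

```julia
# For Ω = kron(A, B) with A of size (m, n) and B of size (p, q), ΔΩ is
# (m*p)×(n*q) and splits into an m×n grid of p×q blocks, giving
#   ∂A[i, j] = sum(ΔΩ_block(i, j) .* B)
#   ∂B       = Σ_{i,j} A[i, j] .* ΔΩ_block(i, j)
function kron_pullback(ΔΩ, A, B)
    (m, n), (p, q) = size(A), size(B)
    block(i, j) = view(ΔΩ, (i-1)*p+1:i*p, (j-1)*q+1:j*q)
    ∂A = [sum(block(i, j) .* B) for i in 1:m, j in 1:n]
    ∂B = sum(A[i, j] .* block(i, j) for i in 1:m, j in 1:n)
    return ∂A, ∂B
end
```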
In test/testchainrules.jl, I compare Zygote.gradient against hand-written gradients for this trivial case, and I make a second comparison using the lazy Kronecker product. I decided to keep similar tests for the higher-dimensional case, but mark them with @test_broken for now.
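Hypothetically, one of those broken tests might look like this (names and sizes assumed, not the actual contents of test/testchainrules.jl):

```julia
using Test, Zygote, Kronecker

A, B = rand(3, 3), rand(2, 2)
X = rand(6, 4)
y = rand(6, 4)   # samples are now columns rather than scalars

loss_lazy(A, B)  = 0.5 * sum(abs2, (A ⊗ B) * X .- y)
loss_dense(A, B) = 0.5 * sum(abs2, kron(A, B) * X .- y)

# The dense kron version serves as the reference gradient:
@test_broken Zygote.gradient(loss_lazy, A, B)[1] ≈ Zygote.gradient(loss_dense, A, B)[1]
```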
I've been experimenting with KroneckerSum as well.
I managed to get the correct values for the pullback:
```julia
function ChainRulesCore.frule((_, ΔA, ΔB), ::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
    Ω = A ⊕ B
    ∂Ω = ΔA ⊕ ΔB
    return Ω, ∂Ω
end
```
```julia
function ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
    nA, nB = size(A, 1), size(B, 1)   # A and B are square here
    function kronecker_sum_pullback(ΔΩ)
        ∂A = nB .* A + Diagonal(fill(tr(B), nA))
        ∂B = nA .* B + Diagonal(fill(tr(A), nB))
        return (NO_FIELDS, ∂A, ∂B)
    end
    return (A ⊕ B), kronecker_sum_pullback
end
```
```julia
nA = 3
nB = 2
Ar = rand(nA, nA)
Br = rand(nB, nB)

Y_lazy, back_lazy = Zygote._pullback(⊕, Ar, Br)
Y, back = Zygote._pullback((x, y) -> kron(x, Diagonal(ones(nB))) + kron(Diagonal(ones(nA)), y), Ar, Br)

julia> back(Y)[2:end] .≈ back_lazy(Y_lazy)[2:end]
(true, true)
```
Of course, this isn't useful for computing the gradient in more complicated expressions, since ΔΩ is not used in computing either ∂A or ∂B in the rrule.
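Since A ⊕ B = A ⊗ I + I ⊗ B is linear in both arguments, a pullback that does use ΔΩ would take the trace of each block for ∂A and the sum of the diagonal blocks for ∂B. A sketch under that assumption, for real square A (nA×nA) and B (nB×nB); this is not code from the PR:

```julia
using LinearAlgebra

# ΔΩ is (nA*nB)×(nA*nB) and splits into an nA×nA grid of nB×nB blocks.
function kronecker_sum_pullback_sketch(ΔΩ, nA, nB)
    block(i, j) = view(ΔΩ, (i-1)*nB+1:i*nB, (j-1)*nB+1:j*nB)
    ∂A = [tr(block(i, j)) for i in 1:nA, j in 1:nA]  # pairs ΔΩ with E_ij ⊗ I
    ∂B = sum(Matrix(block(i, i)) for i in 1:nA)      # pairs ΔΩ with I ⊗ E_rs
    return ∂A, ∂B
end
```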
Note that

```julia
ChainRulesCore.rrule(::typeof(KroneckerSum), A::AbstractMatrix, B::AbstractMatrix)
```

overwrites

```julia
ChainRulesCore.rrule(::typeof(KroneckerProduct), A::AbstractMatrix, B::AbstractMatrix)
```

Should I use something else instead of `::typeof(KroneckerProduct)`/`::typeof(KroneckerSum)`?
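Presumably the collision happens because both names refer to parametric types, so typeof of either one is just UnionAll and the two method signatures are identical:

```julia
julia> using Kronecker

julia> typeof(KroneckerProduct) === typeof(KroneckerSum) === UnionAll
true
```

If that is the cause, dispatching on the type itself, e.g. `::Type{<:KroneckerSum}`, would be one way to keep the two methods distinct.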
Still stuck on this: why does computing gradients work for logdet but not for tr or sum? It should just fall back to the simple shortcuts, for which adjoints already exist.
Technically, it only makes sense to define the adjoints for those functions where Kronecker provides shortcuts, based on these rules: https://en.wikipedia.org/wiki/Matrix_calculus#Identities_in_differential_form
> Still stuck on this: why does computing gradients work for logdet
Can you provide an MWE for logdet?
> Technically, it only makes sense to define the adjoints for those functions where Kronecker provides shortcuts, based on these rules: https://en.wikipedia.org/wiki/Matrix_calculus#Identities_in_differential_form
Maybe I misunderstood, but doesn't this only provide the frules?
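For what it's worth, the differential identities do read most directly as frules, but each one also determines the corresponding rrule by pairing the differential with ΔΩ, as in kron_pullback above. A sketch for kron itself, assuming real matrices (not code from this PR):

```julia
using ChainRulesCore, LinearAlgebra

# d(A ⊗ B) = dA ⊗ B + A ⊗ dB, read forwards, is the frule:
function ChainRulesCore.frule((_, ΔA, ΔB), ::typeof(kron), A::AbstractMatrix, B::AbstractMatrix)
    return kron(A, B), kron(ΔA, B) + kron(A, ΔB)
end
```

Reading the same identity backwards, ⟨ΔΩ, dA ⊗ B⟩ + ⟨ΔΩ, A ⊗ dB⟩ = ⟨∂A, dA⟩ + ⟨∂B, dB⟩ recovers the blockwise pullback.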