AD rules that apply to KroneckerProducts
(Related to #11)
I'm trying to wrap my head around getting gradients with kron/kronecker.
- Is it sufficient to define custom AD rules for the vec-trick with ChainRulesCore.jl, along these lines?
```julia
function rrule(::typeof(*), K::KroneckerProduct, x::AbstractVector)
    function times_vec_pullback(ΔΩ)
        ...
    end
    return K*x, times_vec_pullback
end

function rrule(::typeof(*), K::KroneckerProduct, X::AbstractMatrix)
    function times_mat_pullback(ΔΩ)
        ...
    end
    return K*X, times_mat_pullback
end
```
- Do we also need to define rules for the constructor to get gradients?
```julia
function rrule(::typeof(kronecker), A::AbstractMatrix, B::AbstractMatrix)
    function kronecker_pullback(ΔΩ)
        ...
    end
    return kronecker(A, B), kronecker_pullback
end
```
- Should the pullbacks also be lazy? I found this to be a decent overview on finding vectorized derivatives. Would the pullbacks then just be reshape rules for these vectorized derivatives?
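As a rough sketch of what a lazy rule for `*` could look like: for real matrices, the vec trick (A ⊗ B) vec(X) = vec(B X Aᵀ) gives a pullback made of reshapes and small matrix products, so the full Kronecker matrix is never built. The name `kron_times_vec` is made up for illustration and takes the factors `A` and `B` directly; it is not part of Kronecker.jl or ChainRulesCore, and the complex case (which would need adjoints) is not handled.

```julia
using LinearAlgebra

# Hypothetical helper, not a Kronecker.jl API: forward pass and pullback of
# y = (A ⊗ B) * x, computed via the vec trick (A ⊗ B) vec(X) = vec(B X Aᵀ).
# Assumes real-valued A, B and x.
function kron_times_vec(A::AbstractMatrix, B::AbstractMatrix, x::AbstractVector)
    X = reshape(x, size(B, 2), size(A, 2))
    y = vec(B * X * transpose(A))
    function pullback(Δy)
        # Reshape the cotangent to the shape of B * X * Aᵀ
        D = reshape(Δy, size(B, 1), size(A, 1))
        ∂A = transpose(D) * B * X           # same size as A
        ∂B = D * A * transpose(X)           # same size as B
        ∂x = vec(transpose(B) * D * A)      # equals (A ⊗ B)ᵀ * Δy
        return ∂A, ∂B, ∂x
    end
    return y, pullback
end

A, B = rand(2, 3), rand(4, 5)
x, Δ = rand(15), rand(8)
y, back = kron_times_vec(A, B, x)
∂A, ∂B, ∂x = back(Δ)
```

Wrapping this in a ChainRulesCore `rrule` would additionally need a tangent for the function argument (typically `NoTangent()`) and a decision on how to represent the tangent of `K` in terms of its factors.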
The question is what is held fixed and what you want to differentiate with respect to. I originally conceived Kronecker to work with systems of the form f(K * w), where w is a parameter matrix you might want to optimize. That case should be easy enough.
Taking the gradients of the Kronecker matrix itself would be A ⊗ B => I ⊗ B and A ⊗ I.
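To make the gradient of the Kronecker matrix itself concrete, here is a minimal sketch for the eager case with a dense cotangent `ΔK` of the same size as `kron(A, B)` and real entries. It uses the block structure K[(i-1)p + k, (j-1)q + l] = A[i, j] * B[k, l]. The function `kron_pullback` is a hypothetical name for illustration, not an existing Kronecker.jl or ChainRules function.

```julia
using LinearAlgebra

# Hypothetical helper: pullback of K = kron(A, B) for a dense, real cotangent ΔK.
function kron_pullback(A::AbstractMatrix, B::AbstractMatrix, ΔK::AbstractMatrix)
    m, n = size(A)
    p, q = size(B)
    # View ΔK as a 4-index array: T[k, i, l, j] == ΔK[(i-1)p + k, (j-1)q + l]
    T = reshape(ΔK, p, m, q, n)
    # ∂A[i, j] is the inner product of block (i, j) of ΔK with B
    ∂A = [sum(T[:, i, :, j] .* B) for i in 1:m, j in 1:n]
    # ∂B is the sum of the blocks of ΔK weighted by the entries of A
    ∂B = sum(A[i, j] * T[:, i, :, j] for i in 1:m, j in 1:n)
    return ∂A, ∂B
end

A, B = rand(2, 3), rand(4, 5)
ΔK = rand(8, 15)
∂A, ∂B = kron_pullback(A, B, ΔK)
```

Because `kron` is linear in each factor, each entry of `∂A` can be checked against `dot(ΔK, kron(E, B))`, where `E` is the matching unit matrix.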
Maybe the dot(x, A, y) might also be a special case?
I have been working with ChainRulesCore, so you might open a PR and we can look together?
> Taking the gradients of the Kronecker matrix itself would be A ⊗ B => I ⊗ B and A ⊗ I.
It doesn't appear to be quite that straightforward. Care must be taken in setting the appropriate size of I for each partial derivative. A (conjugate?) transpose needs to be taken of some of the matrices.
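One way to see why the identity sizes differ is the mixed-product property: for A of size m×n and B of size p×q, A ⊗ B = (A ⊗ I_p)(I_n ⊗ B) = (I_m ⊗ B)(A ⊗ I_q). The short check below uses the eager `kron` and a small helper `eye` defined only for this sketch.

```julia
using LinearAlgebra

A, B = rand(2, 3), rand(4, 5)
m, n = size(A)
p, q = size(B)

# Dense k×k identity; `eye` is a helper defined here for illustration only.
eye(k) = Matrix{Float64}(I, k, k)

# Two factorizations of kron(A, B); each needs differently sized identities.
K1 = kron(A, eye(p)) * kron(eye(n), B)   # (A ⊗ I_p)(I_n ⊗ B)
K2 = kron(eye(m), B) * kron(A, eye(q))   # (I_m ⊗ B)(A ⊗ I_q)
```

Which identity appears depends on whether the factor being differentiated sits to the left or right of the other one in the factorization, which is presumably where the size mismatch comes from.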
> I have been working with ChainRulesCore, so you might open a PR and we can look together?
I've managed to put together a semi-working example with the eager kron and Zygote.gradient. I'd have to review how I do the first steps with the chain-rule. I'll open a PR today.
```julia
using LinearAlgebra
using Random
using Zygote

M, N = 3, 2
n_samples = 3

Random.seed!(0)
A = rand(1, N)
B = rand(1, M)
x = rand(M*N, n_samples)
y = rand(n_samples)

model(A, B, X) = kron(A, B) * X

# Squared-error loss; `y` is the global target vector defined above.
function loss(A, B, X)
    Z = model(A, B, X) - y'
    L = 0.5 * Z * Z'
    return L[1]
end

# Hand-written gradient w.r.t. A (A and B are row vectors here).
function gradient_A(A, B, x)
    Z = model(A, B, x) - y'
    n = size(A, 2)
    IA_col = Diagonal(ones(n))
    return Z * (kron(IA_col', B) * x)'
end

# Hand-written gradient w.r.t. B.
function gradient_B(A, B, x)
    Z = model(A, B, x) - y'
    n = size(B, 2)
    IB_col = Diagonal(ones(n))
    return Z * (kron(A, IB_col) * x)'
end

# Compare hand-written gradients with running Zygote.gradient on the loss function
@assert gradient_A(A, B, x) ≈ gradient(loss, A, B, x)[1]
@assert gradient_B(A, B, x) ≈ gradient(loss, A, B, x)[2]

# Show partial derivatives of the loss function w.r.t. the Kronecker factors.
@show gradient(loss, A, B, x)[1:2]
```
> Maybe the dot(x, A, y) might also be a special case?
What did you have in mind for this?