ImplicitDifferentiation.jl
QR decomposition example
Here is an example of differentiating through the QR decomposition with this package. It would be nice to have a similar example for each matrix factorization.
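```julia
using LinearAlgebra, ImplicitDifferentiation, ComponentArrays, Zygote

# Implicit conditions c(A, x) = 0 characterizing the thin QR factors x = (Q, R) of A:
#   - upper triangle (incl. diagonal) of Q'Q equals I  (orthonormal columns)
#   - strictly lower triangle of R is zero             (R upper triangular)
#   - Q * R equals A
function qr_conditions(A, x)
    Q, R = x.Q, x.R
    return vcat(
        vec(UpperTriangular(Q' * Q) + LowerTriangular(R) - I - Diagonal(R)),
        vec(Q * R - A),
    )
end

# Forward solver: compute the thin QR factors and pack them into a ComponentVector
function qr_forward(A)
    qr_res = qr(A)
    Q = Matrix(qr_res.Q)  # thin Q factor, same size as A (assumes A is tall)
    R = qr_res.R          # size(A, 2) × size(A, 2) upper triangular factor
    return ComponentVector(; Q, R)
end

diff_qr = ImplicitFunction(qr_forward, qr_conditions)

A = rand(10, 4)
x = diff_qr(A)
J = Zygote.jacobian(diff_qr, A)[1]  # one row per entry of (Q, R), one column per entry of A
```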
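To make the underlying math concrete, here is a minimal sketch, using only the standard library, of the implicit function theorem applied to these conditions. Differentiating c(A, x(A)) = 0 gives ∂c/∂x · dx/dA + ∂c/∂A = 0, so dx/dA = -(∂c/∂x) \ (∂c/∂A). For a 10×4 matrix the conditions give 16 + 40 = 56 equations for the 56 unknown entries of Q (10×4) and R (4×4), so ∂c/∂x is square. The helpers `qr_residual`, `thin_qr`, `unpack` and `fd_jacobian` are hypothetical names for illustration, not part of ImplicitDifferentiation.jl. The partial Jacobians here come from central finite differences, whereas the package obtains them with automatic differentiation. The result is checked against finite differences of the forward QR itself.

```julia
using LinearAlgebra, Random

# Hypothetical helper: same residual as qr_conditions, with Q and R as plain matrices
function qr_residual(A, Q, R)
    return vcat(
        vec(UpperTriangular(Q' * Q) + LowerTriangular(R) - I - Diagonal(R)),
        vec(Q * R - A),
    )
end

# Hypothetical helper: thin QR factors flattened into one vector x = [vec(Q); vec(R)]
function thin_qr(A)
    F = qr(A)
    return vcat(vec(Matrix(F.Q)), vec(F.R))
end

# Hypothetical helper: split a flattened x back into Q (m×n) and R (n×n)
unpack(x, m, n) = (reshape(x[1:m*n], m, n), reshape(x[m*n+1:end], n, n))

# Hypothetical helper: central finite-difference Jacobian of a vector-valued f at v
function fd_jacobian(f, v; h=1e-6)
    J = zeros(length(f(v)), length(v))
    for j in eachindex(v)
        e = zeros(length(v))
        e[j] = h
        J[:, j] = (f(v + e) - f(v - e)) / (2h)
    end
    return J
end

Random.seed!(0)
m, n = 10, 4
A = rand(m, n)
a = vec(A)
x = thin_qr(A)

# The conditions vanish at the forward solution
@assert norm(qr_residual(A, unpack(x, m, n)...)) < 1e-10

# Partial Jacobians of c(a, x): 56×56 with respect to x, 56×40 with respect to a
Jx = fd_jacobian(y -> qr_residual(A, unpack(y, m, n)...), x)
Ja = fd_jacobian(b -> qr_residual(reshape(b, m, n), unpack(x, m, n)...), a)

# Implicit function theorem: dx/da = -(∂c/∂x) \ (∂c/∂a)
J_implicit = -(Jx \ Ja)

# Reference: finite differences of the forward solver itself
# (assumes the tiny perturbations do not flip the signs LAPACK picks for diag(R))
J_fd = fd_jacobian(b -> thin_qr(reshape(b, m, n)), a)
@assert isapprox(J_implicit, J_fd; rtol=1e-4)
```

Keeping the strictly lower entries of R as unknowns (forced to zero by the conditions) and imposing only the upper triangle of Q'Q is what makes the system square, so ∂c/∂x can be inverted directly.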
Or, even better, a dedicated package: DifferentiableFactorizations.jl.
Solved by https://github.com/mohamed82008/DifferentiableFactorizations.jl