ImplicitDifferentiation.jl

QR decomposition example

Open · mohdibntarek opened this issue 3 years ago · 1 comment

Here is an example of differentiating the QR decomposition with this package. It would be nice to have an example for each factorization.

using LinearAlgebra, ImplicitDifferentiation, ComponentArrays, Zygote

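# Conditions satisfied by the thin QR factors of A (A is m×n with m ≥ n):
#   - upper triangle (incl. diagonal) of Q'Q - I vanishes: Q has orthonormal columns;
#   - strictly lower triangle of R vanishes: R is upper triangular;
#   - Q * R == A.
# The first two are packed into one n×n matrix, giving n^2 + m*n equations
# for the m*n + n^2 unknowns in (Q, R).
function qr_conditions(A, x)
  Q, R = x.Q, x.R
  return vcat(
    vec(UpperTriangular(Q' * Q) + LowerTriangular(R) - I - Diagonal(R)),
    vec(Q * R - A),
  )
end

# Forward pass: plain LAPACK QR, never differentiated itself.
function qr_forward(A)
  qr_res = qr(A)
  Q = copy(qr_res.Q[:, 1:size(A, 2)])  # thin Q, m×n
  R = qr_res.R                          # n×n upper triangular
  return ComponentVector(; Q, R)
end

# Derivatives of the forward pass come from the conditions via the implicit function theorem.
diff_qr = ImplicitFunction(qr_forward, qr_conditions)

A = rand(10, 4)
x = diff_qr(A)

# Rows index the entries of x (Q, then R); columns index the entries of A.
J = Zygote.jacobian(diff_qr, A)[1]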
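To make the role of the conditions concrete, here is a stdlib-only sketch. It is a hedged illustration of the idea, not the package's actual implementation. It checks that LAPACK's thin QR satisfies the same conditions and then applies the implicit function theorem, dx/dA = -(∂c/∂x)⁻¹ ∂c/∂A, with ∂c/∂x obtained by finite differences instead of automatic differentiation. The helpers `qr_residual`, `qr_packed` and `fd_jacobian` are hypothetical names introduced only for this sketch.

```julia
using LinearAlgebra

# Hypothetical helper mirroring `qr_conditions`, with (Q, R) packed into one
# plain vector x = [vec(Q); vec(R)] so that no ComponentArrays is needed.
function qr_residual(A, x)
    m, n = size(A)
    Q = reshape(x[1:m*n], m, n)
    R = reshape(x[m*n+1:end], n, n)
    return vcat(
        vec(UpperTriangular(Q' * Q) + LowerTriangular(R) - I - Diagonal(R)),
        vec(Q * R - A),
    )
end

# Hypothetical forward pass: thin QR factors from LAPACK, packed as [vec(Q); vec(R)].
function qr_packed(A)
    F = qr(A)
    return vcat(vec(Matrix(F.Q)), vec(F.R))  # Matrix(F.Q) is the thin m×n Q
end

# Hypothetical helper: central finite-difference Jacobian of a vector-to-vector f at v.
function fd_jacobian(f, v; h = 1e-6)
    y = f(v)
    J = zeros(length(y), length(v))
    for k in eachindex(v)
        e = zeros(length(v))
        e[k] = h
        J[:, k] = (f(v + e) - f(v - e)) / (2h)
    end
    return J
end

m, n = 10, 4
A = rand(m, n)
x = qr_packed(A)

# 1. The forward solution satisfies the conditions up to rounding.
res_norm = norm(qr_residual(A, x))

# 2. Implicit function theorem: dx/dA = -(∂c/∂x) \ (∂c/∂A).
#    c depends on A only through the term -vec(A), so ∂c/∂A = [0; -I].
dc_dx = fd_jacobian(v -> qr_residual(A, v), x)  # exact up to rounding: c is quadratic in x
dc_dA = vcat(zeros(n^2, m * n), -Matrix(1.0I, m * n, m * n))
J_implicit = -(dc_dx \ dc_dA)

# 3. Compare with finite differences of the forward pass itself.
J_fd = fd_jacobian(a -> qr_packed(reshape(a, m, n)), vec(A))
rel_err = norm(J_implicit - J_fd) / norm(J_fd)
```

Under the assumption that the `ComponentVector` stores `vec(Q)` followed by `vec(R)`, `J_implicit` should correspond to the matrix `J` returned by `Zygote.jacobian(diff_qr, A)` in the snippet above; the package differentiates the conditions automatically rather than by finite differences.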

mohdibntarek, Jul 17 '22 01:07

Or, even better, a dedicated package: DifferentiableFactorizations.jl.

mohdibntarek, Jul 17 '22 01:07

Solved by https://github.com/mohamed82008/DifferentiableFactorizations.jl

gdalle, Aug 16 '22 15:08