ChainRules.jl
cholesky decomposition
I'm getting incorrect results from the rrule for cholesky when A <: LinearAlgebra.HermOrSym.
Passing the input matrix through Matrix first fixes the issue. The mul! definition in the example below relates to this issue.
using Zygote, ChainRules, ChainRulesCore, LinearAlgebra
# Example matrix
A = [2. -1. 0.0; -1. 2. -1.; 0. -1. 2. ]
import LinearAlgebra.mul!
LinearAlgebra.mul!(C, ::ChainRulesCore.ZeroTangent, ::Any, ::Any, b) = C *=b
# incorrect: produces an all-zero Jacobian
Zygote.jacobian(a -> cholesky(Hermitian(a)).L, A)[1]
# both of the following produce the expected result
Zygote.jacobian(a -> cholesky(Matrix(Hermitian(a))).L, A)[1]
Zygote.jacobian(a -> cholesky(a).L, A)[1]
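For context, the 5-argument mul!(C, A, B, α, β) from LinearAlgebra computes A*B*α + C*β and stores the result in C, overwriting it. So an overload for a zero A only has to scale C by β, in place. A quick check with ordinary matrices, using a dense zero matrix as a stand-in for ChainRulesCore.ZeroTangent (the values are arbitrary):

```julia
using LinearAlgebra

B = [1.0 2.0; 3.0 4.0]
C = [1.0 1.0; 1.0 1.0]
β = 0.5

expected = C * β                  # new array: what a zero-A overload should leave in C
mul!(C, zeros(2, 2), B, 2.0, β)   # in place: C = zeros*B*2.0 + C*β
C == expected                     # true
```

This is why the overload in the thread reduces to "multiply C by b": the A*B*α term vanishes, but the update must still land in the caller's C.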
Huh, I really thought that definition of mul! would fix things correctly.
I will have to look more closely.
Oh.
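I think it needs a . (broadcasting): C *= b rebinds the local variable C to a new array and leaves the caller's C untouched, whereas C .*= b updates C in place.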
LinearAlgebra.mul!(C, ::ChainRulesCore.ZeroTangent, ::Any, ::Any, b) = C .*=b
Though that seems likely to be less efficient when b is a Bool (it might still optimize out; if not, it might need an if).
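A minimal sketch of the if mentioned above, assuming b is the β argument of 5-argument mul! (so the update is C = C*β). Bool values are common here because 3-argument mul!(C, A, B) forwards α = true, β = false. The helper name scale_by_beta! is hypothetical; the commented-out line shows one assumed way to wire it into the overload from this thread.

```julia
# Hypothetical helper: the in-place update that mul!(C, A, B, α, β)
# reduces to when A is a structural zero, i.e. C = C*β.
function scale_by_beta!(C::AbstractArray, β)
    if β isa Bool
        if !β
            # β == false: just zero C. (C .*= false gives the same result,
            # since Bool false is a strong zero in Julia, even for NaN.)
            fill!(C, zero(eltype(C)))
        end
        # β == true: C*β == C, so there is nothing to do.
    else
        C .*= β   # broadcast assignment: mutates the caller's C
    end
    return C
end

# Assumed wiring (requires ChainRulesCore to be loaded):
# LinearAlgebra.mul!(C, ::ChainRulesCore.ZeroTangent, ::Any, ::Any, b) = scale_by_beta!(C, b)

C = [1.0 NaN; 3.0 4.0]
scale_by_beta!(C, false)   # C is now all zeros, NaN included

D = [1.0 2.0; 3.0 4.0]
scale_by_beta!(D, 0.5)     # D is now [0.5 1.0; 1.5 2.0]
```

Whether the branch is worth it depends on whether the compiler already elides C .*= true; the sketch just makes the Bool cases explicit.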