Zygote.jl
LowerTriangular causes error
We are trying to optimize over some code that contains something along the lines of:
using LinearAlgebra
using PDMats
A::Matrix{Float64} = ... # some lower triangular matrix we create
PDMat(Cholesky(LowerTriangular(A))) # a fragment of the code used to compute the scalar we want to optimize
And get the error:
MethodError: no method matching LinearAlgebra.LowerTriangular(::NamedTuple{(:data,), Tuple{LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}}})
Closest candidates are:
LinearAlgebra.LowerTriangular(::LinearAlgebra.LowerTriangular) at ~/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/share/julia/stdlib/v1.8/LinearAlgebra/src/triangular.jl:21
LinearAlgebra.LowerTriangular(::AbstractMatrix) at ~/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/share/julia/stdlib/v1.8/LinearAlgebra/src/triangular.jl:23
LinearAlgebra.LowerTriangular(::ChainRulesCore.AbstractThunk) at ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:67
Stacktrace:
[1] (::Zygote.var"#605#606")(Δ::NamedTuple{(:data,), Tuple{LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}}})
@ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/array.jl:430
[2] (::Zygote.var"#2975#back#607"{Zygote.var"#605#606"})(Δ::NamedTuple{(:data,), Tuple{LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}}})
@ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
[3] Pullback
Unfortunately, I don't understand @adjoint and rrules, or how to define a pullback correctly. It seems that for Hermitian and Symmetric there are special cases to deal with being passed a NamedTuple{(:data,)}. Does something like that need to be defined for LowerTriangular, UpperTriangular, etc.? Honestly, I'm totally lost as to how a NamedTuple ends up being injected here (is that a general thing with structs?).
The strangest thing about all this is that we hit this error in our tests: it is raised in the first test set, but the second (identical!) test set works perfectly fine. Does anyone have any ideas what might be going on?
CC @lukekh
I think this is a mismatch between "structural" and "natural" gradients, which is one of the motivations for the gradient projection machinery. MWE:
julia> using Zygote, LinearAlgebra
julia> x = UpperTriangular([1 2; 3 4]);
julia> dump(x)
UpperTriangular{Int64, Matrix{Int64}}
data: Array{Int64}((2, 2)) [1 2; 3 4]
julia> gradient(x -> x[1,1], x)[1] # getindex on an AbstractMatrix "natural", then ProjectTo back to subspace
2×2 UpperTriangular{Float64, Matrix{Float64}}:
1.0 0.0
⋅ 0.0
julia> gradient(x -> x.data[1,1], x)[1] # default "structural" representation, as for any struct
(data = [1.0 0.0; 0.0 0.0],)
julia> gradient(x -> UpperTriangular(x)[1,1], x.data)[1]
2×2 UpperTriangular{Float64, Matrix{Float64}}:
1.0 0.0
⋅ 0.0
julia> gradient(x -> UpperTriangular(x).data[1,1], x.data)[1]
ERROR: MethodError: no method matching UpperTriangular(::NamedTuple{(:data,), Tuple{Matrix{Float64}}})
With https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446 :
julia> gradient(x -> x.data[1,1], x)[1] # structural now converted back to natural
2×2 UpperTriangular{Float64, Matrix{Float64}}:
1.0 0.0
⋅ 0.0
julia> gradient(x -> UpperTriangular(x).data[1,1], x.data)[1]
2×2 UpperTriangular{Float64, Matrix{Float64}}:
1.0 0.0
⋅ 0.0
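To illustrate the "as for any struct" remark, here is a minimal sketch using a made-up one-field struct. `Wrapper` is hypothetical and only stands in for the `data` field of `UpperTriangular`; reading the field directly should give the same kind of NamedTuple gradient:

```julia
using Zygote

# Hypothetical one-field struct, standing in for UpperTriangular's `data` field.
struct Wrapper
    data::Matrix{Float64}
end

w = Wrapper([1.0 2.0; 3.0 4.0])

# Reading a field directly yields a "structural" gradient:
# a NamedTuple with one entry per field of the struct.
g = gradient(w -> w.data[1, 1], w)[1]

# The matrix-shaped part of the gradient lives under the field name.
dA = g.data
```

Any constructor pullback that expects a matrix-like cotangent will then fail if it is handed `g` rather than `g.data`, which is the shape of the MethodError above.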
> this is a mismatch between "structural" and "natural" gradients
Yeah I see - thanks, I was wondering if it was something like that.
From https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446#issuecomment-1158998523 it seems resolving this is a bigger body of work, perhaps? Do you know if there are any workarounds we could apply in the meantime?
It ought to be possible to hack just one case. I thought Zygote._project(x::UpperTriangular, dx::NamedTuple) = UpperTriangular(dx.data) might work, but it seems not to.
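Another possible direction, sketched here under assumptions and not verified against the original failure: since the stack trace points at the constructor's pullback in Zygote's array.jl, one could route the construction through a helper with its own rrule that accepts both gradient forms. `lower` is a hypothetical name, not part of LinearAlgebra or Zygote, and the sketch assumes that Zygote hands structural gradients to ChainRules pullbacks as a `Tangent`:

```julia
using ChainRulesCore, LinearAlgebra

# Hypothetical helper: call this instead of `LowerTriangular`
# inside the code being differentiated.
lower(A::AbstractMatrix) = LowerTriangular(A)

function ChainRulesCore.rrule(::typeof(lower), A::AbstractMatrix)
    function lower_pullback(Δ)
        Δ = unthunk(Δ)
        # Pass zero-like cotangents straight through.
        Δ isa AbstractZero && return (NoTangent(), Δ)
        # A structural tangent keeps the gradient in its `data` field;
        # a natural one is already matrix-like.
        dA = Δ isa Tangent ? unthunk(Δ.data) : Δ
        # Only the lower triangle of A influences the output.
        return (NoTangent(), tril(dA))
    end
    return lower(A), lower_pullback
end

# Calling the rule directly with both kinds of cotangent:
A = [1.0 2.0; 3.0 4.0]
y, back = ChainRulesCore.rrule(lower, A)
natural = back(ones(2, 2))[2]
structural = back(Tangent{typeof(y)}(data = ones(2, 2)))[2]
```

With this in place the original fragment would read PDMat(Cholesky(lower(A))). Whether the rest of that chain then differentiates cleanly is not something this sketch establishes.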
OK thanks for the attempt @mcabbott.
I kinda wish I knew where to start - is there some internals documentation for Zygote etc somewhere so I can get a rough picture of how it hangs together?