Zygote.jl

LowerTriangular causes error

Open andyferris opened this issue 2 years ago • 4 comments

We are trying to optimize over some code which contains something along the lines of:

using LinearAlgebra
using PDMats

A::Matrix{Float64} = ... # some lower triangular matrix we create

PDMat(Cholesky(LowerTriangular(A))) # a fragment of the code used to compute the scalar we want to optimize

And get the error:

  MethodError: no method matching LinearAlgebra.LowerTriangular(::NamedTuple{(:data,), Tuple{LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}}})
  Closest candidates are:
    LinearAlgebra.LowerTriangular(::LinearAlgebra.LowerTriangular) at ~/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/share/julia/stdlib/v1.8/LinearAlgebra/src/triangular.jl:21
    LinearAlgebra.LowerTriangular(::AbstractMatrix) at ~/.julia/juliaup/julia-1.8.5+0.x64.linux.gnu/share/julia/stdlib/v1.8/LinearAlgebra/src/triangular.jl:23
    LinearAlgebra.LowerTriangular(::ChainRulesCore.AbstractThunk) at ~/.julia/packages/ChainRulesCore/a4mIA/src/tangent_types/thunks.jl:67
  Stacktrace:
    [1] (::Zygote.var"#605#606")(Δ::NamedTuple{(:data,), Tuple{LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}}})
      @ Zygote ~/.julia/packages/Zygote/g2w9o/src/lib/array.jl:430
    [2] (::Zygote.var"#2975#back#607"{Zygote.var"#605#606"})(Δ::NamedTuple{(:data,), Tuple{LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}}})
      @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
    [3] Pullback

Unfortunately, I don't understand @adjoint and rrules, or how to define a pullback correctly. It seems that Hermitian and Symmetric have special cases to handle being passed a NamedTuple{(:data,)}. Does something like that need to be defined for LowerTriangular, UpperTriangular, etc.? Honestly, I'm totally lost as to how a NamedTuple ends up being injected here (is that a general thing with structs?).

The strangest thing about all this is that we hit this error in our tests: it is raised in the first test set, but the second (identical!) test set works perfectly fine. Does anyone have any ideas what might be going on?

CC @lukekh

andyferris, Feb 28 '23 03:02

I think this is a mismatch between "structural" and "natural" gradients, which is one of the motivations for the gradient projection machinery. MWE:

julia> using Zygote, LinearAlgebra

julia> x = UpperTriangular([1 2; 3 4]);

julia> dump(x)
UpperTriangular{Int64, Matrix{Int64}}
  data: Array{Int64}((2, 2)) [1 2; 3 4]

julia> gradient(x -> x[1,1], x)[1]  # getindex on an AbstractMatrix "natural", then ProjectTo back to subspace
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 1.0  0.0
  ⋅   0.0

julia> gradient(x -> x.data[1,1], x)[1]  # default "structural" representation, as for any struct 
(data = [1.0 0.0; 0.0 0.0],)

julia> gradient(x -> UpperTriangular(x)[1,1], x.data)[1]
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 1.0  0.0
  ⋅   0.0

julia> gradient(x -> UpperTriangular(x).data[1,1], x.data)[1] 
ERROR: MethodError: no method matching UpperTriangular(::NamedTuple{(:data,), Tuple{Matrix{Float64}}})

With https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446 :

julia> gradient(x -> x.data[1,1], x)[1]  # structural now converted back to natural
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 1.0  0.0
  ⋅   0.0

julia> gradient(x -> UpperTriangular(x).data[1,1], x.data)[1]
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 1.0  0.0
  ⋅   0.0

mcabbott, Mar 01 '23 04:03

> this is a mismatch between "structural" and "natural" gradients

Yeah I see - thanks, I was wondering if it was something like that.

From https://github.com/JuliaDiff/ChainRulesCore.jl/pull/446#issuecomment-1158998523 it seems resolving this is a bigger body of work, perhaps? Do you know if there are any workarounds we could apply in the meantime?

andyferris, Mar 01 '23 05:03

It ought to be possible to hack just one case. I thought Zygote._project(x::UpperTriangular; dx::NamedTuple) = UpperTriangular(dx.data) might work, but it seems not to.

mcabbott, Mar 01 '23 13:03

OK thanks for the attempt @mcabbott.

I kinda wish I knew where to start - is there some internals documentation for Zygote etc somewhere so I can get a rough picture of how it hangs together?

andyferris, Mar 01 '23 22:03