Ternary multiplication Vector'*Diagonal*Vector has NamedTuple gradient
MWE:
```julia
using Zygote
using LinearAlgebra
D = Diagonal(ones(3))
V = ones(3)
function f(x)
    return V'*D*V + ((V - x)'D*(V - x))
end
gradient(f, V)
```
gives the error message:
```
MethodError: no method matching +(::LinearAlgebra.Diagonal{Float64, Vector{Float64}}, ::@NamedTuple{diag::Vector{Float64}})
The function `+` exists, but no method is defined for this combination of argument types.
```
The error is thrown when accumulating the gradients from the two occurrences of D in f(x). No error is thrown if one occurrence is removed. The error can also be avoided by using collect(D), or by adding parentheses around the multiplications to turn them into two binary multiplications.
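Here is a minimal sketch of the parentheses work-around. It assumes the same globals D and V as in the MWE, and the name f_paren is only for illustration. Grouping each term as (a' * D) * b means only the two-argument * method is called. Both occurrences of D should then receive Diagonal gradients that add together without error.

```julia
using Zygote
using LinearAlgebra

D = Diagonal(ones(3))
V = ones(3)

# Illustrative variant of f: each quadratic form is written as two binary
# multiplications, (a' * D) * b, so the 3-arg
# *(::Adjoint, ::Diagonal, ::AbstractVector) method is never called.
function f_paren(x)
    return (V' * D) * V + ((V - x)' * D) * (V - x)
end

# f_paren(x) = const + (V - x)' * D * (V - x), whose gradient is -2 * D * (V - x).
g = gradient(f_paren, zeros(3))[1]
```

Since D is the identity here, the gradient at x = zeros(3) should be [-2.0, -2.0, -2.0].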
If I insert @showgrad(D), it shows that the gradient w.r.t. D is a NamedTuple and not a Diagonal.
These things are annoying. We tried to build a general solution, once upon a time...
This could be solved by adding a rule, ideally to ChainRules.jl.
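Here is a minimal sketch of what such a rule could look like. It attaches a ChainRulesCore.rrule to a separate, hypothetical helper quadform(x, D, y), which stands in for x' * D * y with real inputs only, rather than to Base.* itself. The pullback returns the gradient for D as a Diagonal rather than a NamedTuple, so gradients from several occurrences of D can be added. This only illustrates the idea under those assumptions. It is not the rule proposed for ChainRules.jl.

```julia
using Zygote
using LinearAlgebra
using ChainRulesCore

# Hypothetical helper standing in for x' * D * y (real inputs only).
quadform(x::AbstractVector{<:Real}, D::Diagonal{<:Real}, y::AbstractVector{<:Real}) =
    sum(x .* D.diag .* y)

function ChainRulesCore.rrule(::typeof(quadform), x::AbstractVector{<:Real},
                              D::Diagonal{<:Real}, y::AbstractVector{<:Real})
    z = quadform(x, D, y)
    function quadform_pullback(dz_raw)
        dz = unthunk(dz_raw)
        dx = dz .* D.diag .* y        # ∂z/∂x = D * y
        dD = Diagonal(dz .* x .* y)   # matrix-shaped gradient, not a NamedTuple
        dy = dz .* D.diag .* x        # ∂z/∂y = D' * x
        return NoTangent(), dx, dD, dy
    end
    return z, quadform_pullback
end

D = Diagonal(ones(3))
V = ones(3)
f_rule(x) = quadform(V, D, V) + quadform(V - x, D, V - x)

g = gradient(f_rule, zeros(3))[1]
gD = gradient(A -> quadform(V, A, V), D)[1]
```

The key choice is returning Diagonal(...) for D. That is what allows the two contributions to be accumulated.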
One work-around might be to use the 3-arg dot instead of this 3-arg * method. Or, as you say, avoid the 3-arg * altogether:
> or parentheses around the multiplications to turn them into two binary multiplications.
This is implicitly what your second term is doing, BTW:
```julia-repl
julia> Zygote.gradient(D -> V'*D*V, D)[1]
(diag = [1.0, 1.0, 1.0],)

julia> @which V'*D*V
*(x::Adjoint{T, <:AbstractVector} where T, D::Diagonal, y::AbstractVector)
     @ LinearAlgebra ~/.julia/juliaup/julia-1.11.6+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/diagonal.jl:952

julia> Zygote.gradient(D -> V'D*V, D)[1]
3×3 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅    ⋅
  ⋅   1.0   ⋅
  ⋅    ⋅   1.0

julia> :(V'D*V)
:((V' * D) * V)

julia> @which V'D*V
*(u::Adjoint{<:Number, <:AbstractVector}, v::AbstractVector{<:Number})
     @ LinearAlgebra ~/.julia/juliaup/julia-1.11.6+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:478
```
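The difference between the two spellings is already visible at parse time. The small check below uses only Base expression objects, with no Zygote involved. V' * D * V parses as a single call to * with three operands. V'D * V parses as a nested pair of two-argument calls.

```julia
# Parse both spellings without evaluating them.
ex3 = :(V' * D * V)   # chained * is parsed as one n-ary call
ex2 = :(V'D * V)      # juxtaposition V'D binds first: (V' * D) * V

# ex3.args is [:*, :(V'), :D, :V]: one call with three operands.
println(ex3.head, " with ", length(ex3.args) - 1, " operands")

# ex2.args is [:*, :(V' * D), :V]: the inner V' * D is its own 2-arg call.
println(ex2.head, " with ", length(ex2.args) - 1, " operands; inner: ", ex2.args[2])
```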
A NamedTuple arises as the gradient of getfield. That is, Zygote is looking inside the code of this * method for ::Diagonal, which reads out D.diag. Zygote doesn't know that this struct is really a matrix, whose gradient could instead be another Diagonal holding the gradient for that field.
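The getfield point can be seen in a small example, assuming a fresh D with distinct diagonal entries. Differentiating a function that reads D.diag directly gives a NamedTuple. The same numbers can then be repackaged as a Diagonal by hand. The repackaging line is only a sketch of what a matrix-aware rule would return, not something Zygote does automatically.

```julia
using Zygote
using LinearAlgebra

D = Diagonal([1.0, 2.0, 3.0])

# sum(A.diag) only touches the field, so Zygote differentiates getfield
# and returns a structural gradient: a NamedTuple with one entry per field.
g_field = gradient(A -> sum(A.diag), D)[1]

# Repackaging that field gradient as a Diagonal gives the matrix-shaped
# gradient for the same function.
g_matrix = Diagonal(collect(g_field.diag))
```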
For dot there is (I think) a rule, and it produces a Diagonal gradient:
```julia-repl
julia> Zygote.gradient(D -> dot(V, D, V), D)[1]
3×3 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅    ⋅
  ⋅   1.0   ⋅
  ⋅    ⋅   1.0
```
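Applied to the original MWE, the dot work-around might look like the minimal sketch below (f_dot is an illustrative name). It assumes, as the output above suggests, that the rule used for 3-arg dot returns a Diagonal gradient for D. The contributions from both terms can then be accumulated.

```julia
using Zygote
using LinearAlgebra

D = Diagonal(ones(3))
V = ones(3)

# Same function as f in the MWE, with each x' * D * y written as the
# 3-arg dot(x, D, y), which computes the same value for real inputs.
f_dot(x) = dot(V, D, V) + dot(V - x, D, V - x)

g = gradient(f_dot, zeros(3))[1]
```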
I'm not sure why this doesn't make Zygote use it:
```julia-repl
julia> ChainRules.rrule(::typeof(*), x::Adjoint{T, <:AbstractVector} where T, D::Diagonal, y::AbstractVector) = ChainRules.rrule(dot, x.parent, D, y)

julia> Zygote.gradient(D -> V'*D*V, D)[1]
(diag = [1.0, 1.0, 1.0],)
```