Zygote.jl

Ternary multiplication Vector'*Diagonal*Vector has NamedTuple gradient

Open johroj opened this issue 3 months ago • 1 comments

MWE:

using Zygote
using LinearAlgebra

D = (Diagonal(ones(3)))

V = ones(3)

function f(x)
	return V'*D*V + ((V - x)'D*(V - x))
end
gradient(f, V)

gives the error message

MethodError: no method matching +(::LinearAlgebra.Diagonal{Float64, Vector{Float64}}, ::@NamedTuple{diag::Vector{Float64}})

The function `+` exists, but no method is defined for this combination of argument types.

The error is thrown when accumulating the gradients from the two occurrences of D in f(x). No error is thrown if one occurrence is removed. The error also goes away if D is wrapped in collect(D), or if parentheses are added around the multiplications so that each ternary product becomes two binary multiplications.

If I insert @showgrad(D), it shows that the gradient w.r.t. D is a NamedTuple and not a Diagonal.

johroj, Sep 12 '25 12:09

These things are annoying. We tried to build a general solution, once upon a time...

This could be solved by adding a rule, ideally to ChainRules.jl.

One work-around might be to use 3-arg dot instead of this 3-arg * method. Or, as you say, to avoid the 3-arg *:

"or parentheses around the multiplications to turn them into two binary multiplications."

This is implicitly what your second term is doing, BTW:

julia> Zygote.gradient(D -> V'*D*V, D)[1]
(diag = [1.0, 1.0, 1.0],)

julia> @which V'*D*V
*(x::Adjoint{T, <:AbstractVector} where T, D::Diagonal, y::AbstractVector)
     @ LinearAlgebra ~/.julia/juliaup/julia-1.11.6+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/diagonal.jl:952

julia> Zygote.gradient(D -> V'D*V, D)[1]
3×3 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅    ⋅ 
  ⋅   1.0   ⋅ 
  ⋅    ⋅   1.0

julia> :(V'D*V)
:((V' * D) * V)

julia> @which V'D*V
*(u::Adjoint{<:Number, <:AbstractVector}, v::AbstractVector{<:Number})
     @ LinearAlgebra ~/.julia/juliaup/julia-1.11.6+0.aarch64.apple.darwin14/share/julia/stdlib/v1.11/LinearAlgebra/src/adjtrans.jl:478

The NamedTuple arises as the gradient of getfield. That is, Zygote differentiates through the code of the 3-arg * method for ::Diagonal, which reads out D.diag, and it doesn't know that this struct is really a matrix, whose gradient can be another Diagonal holding the gradient for this field.
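To see that the two representations carry the same numbers, here is a short worked derivation for real V. The symbols d_i (for D.diag[i]) and v_i (for V[i]) are introduced here only for illustration:

$$
V^\top D V = \sum_i d_i v_i^2,
\qquad
\frac{\partial}{\partial d_i}\left(V^\top D V\right) = v_i^2 .
$$

With V = ones(3) every entry is 1.0, which matches both the diag field of the NamedTuple and the diagonal of the Diagonal result shown above. Only the container differs, and that difference is what breaks the accumulation with +.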

For dot there is (I think) a rule, and it produces a Diagonal gradient:

julia> Zygote.gradient(D -> dot(V, D, V), D)[1]
3×3 Diagonal{Float64, Vector{Float64}}:
 1.0   ⋅    ⋅ 
  ⋅   1.0   ⋅ 
  ⋅    ⋅   1.0
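Applied to the original MWE, the two work-arounds could look roughly like the sketch below. The names f_paren, f_dot and x0 are hypothetical and introduced only for this example. The sketch assumes that both the binary products and 3-arg dot yield a Diagonal gradient for D, as in the REPL output above, so that the contributions from the two terms can be accumulated.

```julia
using Zygote
using LinearAlgebra

D = Diagonal(ones(3))
V = ones(3)

# Work-around 1 (hypothetical name): explicit parentheses, so every
# product is a binary `*` and the 3-arg Adjoint*Diagonal*Vector method
# is never called.
f_paren(x) = (V' * D) * V + ((V - x)' * D) * (V - x)

# Work-around 2 (hypothetical name): 3-arg `dot`, which for real
# vectors computes the same value as V' * D * V.
f_dot(x) = dot(V, D, V) + dot(V - x, D, V - x)

# Evaluate away from x = V so that the gradient is not trivially zero.
x0 = zeros(3)
g_paren = gradient(f_paren, x0)[1]
g_dot = gradient(f_dot, x0)[1]
```

For real inputs both functions compute the same value as the original f, so the two gradients should agree with each other. Which form is preferable is a matter of taste: the parenthesised one stays closest to the original code, while dot avoids forming the intermediate V' * D.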

I'm not sure why defining an rrule for the 3-arg * that forwards to the dot rule doesn't make Zygote use it:

julia> ChainRules.rrule(::typeof(*), x::Adjoint{T, <:AbstractVector} where T, D::Diagonal, y::AbstractVector) = ChainRules.rrule(dot, x.parent, D, y)

julia> Zygote.gradient(D -> V'*D*V, D)[1]
(diag = [1.0, 1.0, 1.0],)

mcabbott, Sep 14 '25 01:09