Zygote.jl
Error when computing the gradient of a function that uses a `Dict`
Hi,
I encountered the following errors when working with functions based on dictionaries. Below are the minimal failing examples (MFEs) and my naive attempts at fixing them. (They seem to require some methods and `adjoint`s for the `Base.ValueIterator` type.)
using Zygote
"""
    mfe1(x::Vector)

Minimal failing example: put `x` and its element-wise square into a `Dict`, take the
sum of every stored value, and return the total of those sums.

The reduction maps `sum` over `values(dict)` (a `Base.ValueIterator`). In the reported
Zygote version, differentiating that step raised
`MethodError: no method matching size(::Base.ValueIterator{...})`.
"""
function mfe1(x::Vector)
    squared = x .^ 2
    dict = Dict(:a => x, :b => squared)
    value_sums = map(sum, values(dict))
    return sum(value_sums)
end
x = rand(3)  # random Float64 input vector of length 3
Zygote.gradient(mfe1, x)  # reported to throw: MethodError: no method matching size(::Base.ValueIterator{...})
The above results in the following error:
ERROR: MethodError: no method matching size(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
Closest candidates are:
size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}) at ~/julia-src/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/qr.jl:581
size(::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY, LinearAlgebra.QRPivoted}, ::Integer) at ~/julia-src/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/qr.jl:580
size(::Union{LinearAlgebra.Cholesky, LinearAlgebra.CholeskyPivoted}) at ~/julia-src/julia-1.8.5/share/julia/stdlib/v1.8/LinearAlgebra/src/cholesky.jl:514
...
Stacktrace:
[1] axes
@ ./abstractarray.jl:95 [inlined]
[2] _tryaxes(x::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:188
[3] map
@ ./tuple.jl:221 [inlined]
[4] ∇map(cx::Zygote.Context{false}, f::typeof(sum), args::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:203
[5] _pullback(cx::Zygote.Context{false}, #unused#::typeof(collect), g::Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/array.jl:244
[6] _pullback
@ ./abstractarray.jl:2961 [inlined]
[7] _pullback(::Zygote.Context{false}, ::typeof(map), ::typeof(sum), ::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[8] _pullback
@ ./REPL[2]:4 [inlined]
[9] _pullback(ctx::Zygote.Context{false}, f::typeof(mfe1), args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[10] pullback(f::Function, cx::Zygote.Context{false}, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:44
[11] pullback
@ ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:42 [inlined]
[12] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:96
Since the above asks for `size(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})`, and noticing that the method `length(::Base.ValueIterator{Dict{Symbol, Vector{Float64}}})` already exists, I tried adding the following method:
# Give `keys(d)` / `values(d)` iterators a one-dimensional `size`, equal to their
# length. The stack trace shows Zygote's `_tryaxes` (inside `∇map`) calling `axes`,
# which needs `size`; that call raised MethodError on a `Base.ValueIterator`.
# Uses the public `length(v)` rather than reaching into the internal `v.dict` field.
# WARNING: this is type piracy. Neither `Base.size` nor `KeySet`/`ValueIterator` is
# owned by this code, so it changes behaviour globally and may conflict with other
# packages. Prefer a fix inside Zygote/ChainRules, or avoid `values` in your own code.
Base.size(v::Union{Base.KeySet,Base.ValueIterator}) = (length(v),)
I don't know whether this is the right way to go, but it makes the forward pass error-free. However, `Zygote.gradient` now requests an `adjoint`; see the updated error:
ERROR: Need an adjoint for constructor Base.ValueIterator{Dict{Symbol, Vector{Float64}}}. Gradient is of type Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] (::Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:330
[3] (::Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
[4] Pullback
@ ./abstractdict.jl:48 [inlined]
[5] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[6] Pullback
@ ./abstractdict.jl:48 [inlined]
[7] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[8] Pullback
@ ./abstractdict.jl:131 [inlined]
[9] (::Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}})(Δ::Vector{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[10] Pullback
@ ./REPL[2]:4 [inlined]
[11] (::Zygote.Pullback{Tuple{typeof(mfe1), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.Pullback{Tuple{typeof(map), typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Base.ValueIterator{Dict{Symbol, 
Vector{Float64}}}}, Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(sum)}, typeof(sum)}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}, Nothing, false}}}}}}, Zygote.var"#collect_pullback#715"{Zygote.var"#map_back#677"{typeof(sum), 1, Tuple{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Tuple{Base.OneTo{Int64}}}, Vector{Tuple{Float64, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}}}}, Nothing}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe1), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.Pullback{Tuple{typeof(map), typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}}, typeof(sum), Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), 
Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{}}, Zygote.Pullback{Tuple{typeof(convert), Type{typeof(sum)}, typeof(sum)}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.Generator{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, typeof(sum)}, Nothing, false}}}}}}, Zygote.var"#collect_pullback#715"{Zygote.var"#map_back#677"{typeof(sum), 1, Tuple{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Tuple{Tuple{Base.OneTo{Int64}}}, Vector{Tuple{Float64, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}}}}, Nothing}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
[13] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
Independently of the above, the following alternative MFE
"""
    mfe2(x::Vector)

Alternative minimal failing example: store `x` and its element-wise square in a
`Dict`, concatenate all stored values by splatting `values(dict)` into `vcat`, and
return the sum of the concatenated vector.

In the reported Zygote version, the pullback for the `Base.ValueIterator`
constructor raised "Need an adjoint for constructor Base.ValueIterator".
"""
function mfe2(x::Vector)
    squared = x .^ 2
    dict = Dict(:a => x, :b => squared)
    stacked = vcat(values(dict)...)
    return sum(stacked)
end
throws the same `Need an adjoint` error as above:
ERROR: Need an adjoint for constructor Base.ValueIterator{Dict{Symbol, Vector{Float64}}}. Gradient is of type Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] (::Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:330
[3] (::Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
[4] Pullback
@ ./abstractdict.jl:48 [inlined]
[5] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[6] Pullback
@ ./abstractdict.jl:48 [inlined]
[7] (::Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[8] Pullback
@ ./abstractdict.jl:131 [inlined]
[9] (::Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}})(Δ::Tuple{FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}, FillArrays.Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[10] Pullback
@ ./REPL[2]:4 [inlined]
[11] (::Zygote.Pullback{Tuple{typeof(mfe2), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}, Tuple{Tuple{Int64}, 
Tuple{Int64}}, Val{1}}}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe2), Vector{Float64}}, Tuple{Zygote.Pullback{Tuple{typeof(values), Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}}, Dict{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{typeof(convert), Type{Dict{Symbol, Vector{Float64}}}, Dict{Symbol, Vector{Float64}}}, Tuple{}}, Zygote.var"#2177#back#309"{Zygote.Jnew{Base.ValueIterator{Dict{Symbol, Vector{Float64}}}, Nothing, false}}}}}}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2988#back#777"{Zygote.var"#771#775"{Vector{Float64}}}, Zygote.Pullback{Tuple{Type{Pair}, Symbol, Vector{Float64}}, Tuple{Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Vector{Float64}}, Vector{Float64}}, Tuple{}}, Zygote.ZBack{ChainRules.var"#fieldtype_pullback#421"}, Zygote.var"#2177#back#309"{Zygote.Jnew{Pair{Symbol, Vector{Float64}}, Nothing, false}}, Zygote.Pullback{Tuple{typeof(Core.convert), Type{Symbol}, Symbol}, Tuple{}}}}, Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Int64}, Zygote.ZBack{ChainRules.var"#vcat_pullback#1415"{Tuple{ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}, ChainRulesCore.ProjectTo{AbstractArray, NamedTuple{(:element, :axes), Tuple{ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}, Tuple{Base.OneTo{Int64}}}}}}, 
Tuple{Tuple{Int64}, Tuple{Int64}}, Val{1}}}}}, Zygote.var"#3863#back#1242"{Zygote.var"#1238#1241"{2, Vector{Float64}}}, Zygote.var"#1892#back#157"{Zygote.var"#153#156"}, Zygote.Pullback{Tuple{Type{Dict}, Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}, Tuple{Zygote.Pullback{Tuple{Type{Dict{Symbol, Vector{Float64}}}, Tuple{Pair{Symbol, Vector{Float64}}, Pair{Symbol, Vector{Float64}}}}, Any}}}, Zygote.Pullback{Tuple{typeof(Base.Broadcast.materialize), Vector{Float64}}, Tuple{}}}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
[13] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
I would be happy to know whether this can be fixed by writing the `adjoint` that the error requests, or whether there is a workaround for this issue. Thank you!
Just to update: the following variation (MWE), which loops over all the keys, is a workaround. (So the problem is the missing rules and methods for `Base.ValueIterator`, which is used by the functions above.)
"""
    mwe(x::Vector)

Workaround variant: return the sum of all elements of `x` and of `x .^ 2`, both kept
in a `Dict`.

Values are reached by looping over `keys(dict)` and indexing the dict, never by
iterating `values(dict)`, so no `Base.ValueIterator` is built. The running total
starts at `zero(eltype(x))`.
"""
function mwe(x::Vector)
    squared = x .^ 2
    dict = Dict(:a => x, :b => squared)
    total = zero(eltype(x))
    for name in keys(dict)
        total += sum(dict[name])
    end
    return total
end
x = rand(3)  # random Float64 input vector of length 3
Zygote.gradient(mwe, x) # works! (keys-based loop avoids iterating `values`)
Edit: I realised this is not general enough; for example, if the values of the `Dict` have different `eltype`s, then starting the sum from `zero(eltype(x))` is probably not a good idea.
After some trial and error, I have a more generic form of the above workaround, for which `Zygote.gradient` works:
"""
    mwe_generic(x::Vector)

Generic workaround: return the sum of all elements of every value stored in a `Dict`
holding `x` and `x .^ 2`. It loops over `keys` and indexes the dict, instead of
iterating `values` in the reduction.

The accumulator is seeded with `zero` of the element type of one stored value.
Values with other element types are handled by promotion in `+=`. Seeding from the
element *type* instead of `first(values(collection))[1]` means an empty `x` returns
zero rather than throwing a `BoundsError`.

NOTE(review): the previous form was reported to work with `Zygote.gradient`; re-check
`Zygote.gradient(mwe_generic, x)` with this seeding expression.
"""
function mwe_generic(x::Vector)
    y = x .^ 2
    collection = Dict(:a => x, :b => y)
    # Only the element type is needed. `Dict` iteration order is unspecified, but all
    # values here share one type, so which value comes "first" does not matter.
    s = zero(eltype(first(values(collection))))
    for k in keys(collection)
        # No `@inbounds`: `collection[k]` is a key lookup, not array indexing.
        s += sum(collection[k])
    end
    return s
end
x = rand(3)  # random Float64 input vector of length 3
Zygote.gradient(mwe_generic,x) # works! :) (as reported in this issue)
But it would still be good to have the missing methods and `adjoint`s for `Base.ValueIterator`, so that the original MFE works!
Unfortunately, the above workaround doesn't work for `IdDict`: iterating an `IdDict` hits a `ccall` (`jl_eqtable_nextind`), which Zygote cannot differentiate through. See the following:
"""
    mfe_IdDict(x::Vector)

`IdDict` version of the keys-based workaround: return the sum of all elements of `x`
and `x .^ 2`, both stored in an `IdDict`, by looping over `keys` and indexing.

In the reported Zygote version this failed under `Zygote.gradient`. Iterating an
`IdDict` (here via `keys`) reaches the foreign call `jl_eqtable_nextind`, which
Zygote cannot differentiate through.

The accumulator is seeded with `zero` of the element type of one stored value, so an
empty `x` returns zero instead of throwing a `BoundsError`.
"""
function mfe_IdDict(x::Vector)
    y = x .^ 2
    collection = IdDict(:a => x, :b => y)
    # Only the element type is needed to seed the sum; indexing `[1]` would fail on an
    # empty `x`. Other value element types are handled by promotion in `+=`.
    s = zero(eltype(first(values(collection))))
    for k in keys(collection)
        # No `@inbounds`: `collection[k]` is a key lookup, not array indexing.
        s += sum(collection[k])
    end
    return s
end
julia> Zygote.gradient(mfe_IdDict,x)
ERROR: Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_eqtable_nextind), UInt64, svec(Any, UInt64), 0, :(:ccall), %2, %5, %4)).
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] Pullback
@ ./iddict.jl:143 [inlined]
[3] (::Zygote.Pullback{Tuple{typeof(Base._oidd_nextind), Vector{Any}, Int64}, Tuple{Zygote.Pullback{Tuple{typeof(Base.cconvert), Type{UInt64}, Int64}, Tuple{Zygote.ZBack{Zygote.var"#convert_pullback#325"}}}, Zygote.Pullback{Tuple{typeof(reinterpret), Type{Int64}, UInt64}, Tuple{Zygote.Pullback{Tuple{Core.IntrinsicFunction, Type{Int64}, UInt64}, Tuple{Core.IntrinsicFunction}}}}, Zygote.Pullback{Tuple{typeof(Base.unsafe_convert), Type{UInt64}, UInt64}, Tuple{}}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[4] Pullback
@ ./iddict.jl:146 [inlined]
[5] (::Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[6] #287
@ ~/.julia/packages/Zygote/SuKWp/src/lib/lib.jl:206 [inlined]
[7] (::Zygote.var"#2139#back#289"{Zygote.var"#287#288"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.Pullback{Tuple{typeof(iterate), IdDict{Symbol, Vector{Float64}}, Int64}, Any}}})(Δ::Nothing)
@ Zygote ~/.julia/packages/ZygoteRules/OgCVT/src/adjoint.jl:71
[8] Pullback
@ ./abstractdict.jl:64 [inlined]
[9] (::Zygote.Pullback{Tuple{typeof(iterate), Base.KeySet{Symbol, IdDict{Symbol, Vector{Float64}}}, Int64}, Any})(Δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[10] Pullback
@ ./REPL[6]:7 [inlined]
[11] (::Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface2.jl:0
[12] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{typeof(mfe_IdDict), Vector{Float64}}, Any}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:45
[13] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/SuKWp/src/compiler/interface.jl:97
Hi @ToucheSir, are there plans to make Zygote work with `IdDict`? (Should I open a separate issue? I haven't found any `IdDict`-related issue in the issues section here.)
There are no plans to make Zygote work better with any kind of `Dict`, but only because there is no developer capacity to do so; that is why I added the labels above. Dicts are perhaps one of the trickiest types to add new functionality or fix bugs for in Zygote, but if any brave soul wants to try, I'd be happy to guide them.