Zygote.jl
Gradient of dictionary doesn't contain the keys with zero gradient
I would expect the gradient of a dictionary to behave like the gradient of a named tuple and contain all of the keys of the original object. For a Dict, however, keys with zero gradient (nothing) are dropped:
julia> loss(model) = sum(abs2, model[:a])
loss (generic function with 1 method)
julia> nt = (a = [1.0,2.0], b = [3.0,4.0], c = 1);
julia> gradient(loss, nt)[1]
(a = [2.0, 4.0], b = nothing, c = nothing)
julia> d = Dict(:a => [1.0,2.0], :b => [3.0,4.0], :c => 1);
julia> gradient(loss, d)[1]
Dict{Any, Any} with 1 entry:
  :a => [2.0, 4.0]
Zygote auto-canonicalizes custom structs/NamedTuples but does not do so for Dicts. I don't quite understand the theory behind canonical vs non-canonical tangent types (the ChainRules docs seem to talk about it a little), but one counter-argument would be that Dicts are to NamedTuples what sparse arrays are to dense arrays. In other words, not every index (key) in the primal needs to be defined in the tangent. Which approach is faster, more ergonomic, or more correct? I'm not sure.
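In the meantime, if someone needs the NamedTuple-like "dense" behaviour, it can probably be recovered after the fact. Below is a minimal sketch, not anything Zygote provides: `fill_missing_keys` is a hypothetical helper name, and the gradient is written out by hand to match the output in the issue so the snippet runs without Zygote. The `Nothing` method assumes the whole gradient may come back as `nothing` when the loss does not touch the Dict at all.

```julia
# Hypothetical helper (not part of Zygote): build a gradient Dict that has
# every key of the primal, with `nothing` for keys missing from `grad`.
function fill_missing_keys(grad::AbstractDict, primal::AbstractDict)
    return Dict{Any,Any}(k => get(grad, k, nothing) for k in keys(primal))
end

# Assumption: the whole gradient may be `nothing`; treat that as "all keys zero".
fill_missing_keys(::Nothing, primal::AbstractDict) =
    Dict{Any,Any}(k => nothing for k in keys(primal))

d = Dict(:a => [1.0, 2.0], :b => [3.0, 4.0], :c => 1)

# Stand-in for `gradient(loss, d)[1]` from the issue, written out by hand.
g = Dict{Any,Any}(:a => [2.0, 4.0])

full = fill_missing_keys(g, d)
```

With this, `full` has the keys `:a`, `:b` and `:c`, where `:b` and `:c` map to `nothing`, mirroring the NamedTuple result. It costs an extra pass over the keys, which is the price the sparse representation avoids.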