Diffractor.jl
Gradients with respect to dictionaries don't work
In Julia 1.7.3:
julia> Zygote.gradient(x -> x["foo"]["bar"]^2, Dict("foo" => Dict("bar" => 5)))
(Dict{Any, Any}("foo" => Dict{Any, Any}("bar" => 10)),)
julia> Diffractor.gradient(x -> x["foo"]["bar"]^2, Dict("foo" => Dict("bar" => 5)))
ERROR: ArgumentError: Tangent for the primal Dict{String, Int64} should be backed by a AbstractDict type, not by NamedTuple{(:vals,), Tuple{Vector{Int64}}}.
Stacktrace:
[1] _backing_error(P::Type, G::Type, E::Type)
@ ChainRulesCore C:\Users\anhin\.julia\packages\ChainRulesCore\ctmSK\src\tangent_types\tangent.jl:62
[2] ChainRulesCore.Tangent{Dict{String, Int64}, NamedTuple{(:vals,), Tuple{Vector{Int64}}}}(backing::NamedTuple{(:vals,), Tuple{Vector{Int64}}})
@ ChainRulesCore C:\Users\anhin\.julia\packages\ChainRulesCore\ctmSK\src\tangent_types\tangent.jl:33
[3] (::Diffractor.var"#162#164"{Symbol, DataType})(Δ::Vector{Int64})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:309
[4] (::Diffractor.EvenOddOdd{1, 1, Diffractor.var"#162#164"{Symbol, DataType}, Diffractor.var"#163#165"{Symbol}})(Δ::Vector{Int64})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:289
[5] ∂⃖¹₁getproperty
@ .\none:1
And a second example, in which the differentiated function builds a new `Dict` by mutation:
julia> pars = Dict(:x=>0f0, "y"=>4f0, 8=>-3f0)
Dict{Any, Float32} with 3 entries:
"y" => 4.0
8 => -3.0
:x => 0.0
julia> function my_map(my_f, my_dict)  # reproducer helper: returns a new Dict mapping each key of `my_dict` to `my_f` of its value
new_dict = Dict()  # untyped Dict{Any, Any}; the Diffractor stack trace for this example fails inside this constructor call (frame `my_map @ REPL[20]:2`)
for k in keys(my_dict)
new_dict[k] = my_f(my_dict[k])  # in-place `setindex!` on the freshly built Dict (the mutation discussed later in the thread)
end
new_dict
end
my_map (generic function with 1 method)
julia> function my_sum(my_dict)  # reproducer helper: adds up all values of `my_dict` by iterating its keys
s = 0f0  # Float32 zero literal as the accumulator's starting value
for k in keys(my_dict)
s += my_dict[k]
end
s
end
my_sum (generic function with 1 method)
julia> my_sum(my_map(x->x^2, pars))
25.0f0
julia> Zygote.gradient(pars -> my_sum(my_map(x->x^2, pars)), pars)
(Dict{Any, Any}("y" => 8.0f0, 8 => -6.0f0, :x => 0.0f0),)
julia> Diffractor.gradient(pars -> my_sum(my_map(x->x^2, pars)), pars)
ERROR: MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.sizeof), ::Int64)
Closest candidates are:
(::Diffractor.∂⃖recurse)(::Any...) at C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:406
Stacktrace:
[1] macro expansion
@ C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0 [inlined]
[2] (::Diffractor.∂⃖recurse{1})(::typeof(Core.sizeof), ::Int64)
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:415
[3] (::∂⃖{1})(f::typeof(Core.sizeof), args::Int64)
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[4] is_top_bit_set
@ .\boot.jl:616 [inlined]
[5] (::∂⃖{1})(f::typeof(Core.is_top_bit_set), args::Int64)
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[6] check_top_bit
@ .\boot.jl:626 [inlined]
[7] (::Diffractor.∂⃖recurse{1})(::typeof(Core.check_top_bit), ::Type{UInt64}, ::Int64)
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
[8] (::∂⃖{1})(::typeof(Core.check_top_bit), ::Type, ::Vararg{Any})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[9] toUInt64
@ .\boot.jl:737 [inlined]
[10] (::Diffractor.∂⃖recurse{1})(::typeof(Core.toUInt64), ::Int64)
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
[11] (::∂⃖{1})(f::typeof(Core.toUInt64), args::Int64)
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[12] UInt64
@ .\boot.jl:767 [inlined]
[13] (::Diffractor.∂⃖recurse{1})(::Type{UInt64}, ::Int64)
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
[14] (::∂⃖{1})(f::Type{UInt64}, args::Int64)
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[15] convert
@ .\number.jl:7 [inlined]
[16] (::Diffractor.∂⃖recurse{1})(::typeof(convert), ::Type{UInt64}, ::Int64)
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
[17] (::∂⃖{1})(::typeof(convert), ::Type, ::Vararg{Any})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[18] Dict
@ .\dict.jl:90 [inlined]
[19] (::Diffractor.∂⃖recurse{1})(args::Type{Dict{Any, Any}})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
[20] (::∂⃖{1})(::Type{Dict{Any, Any}})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[21] Dict
@ .\dict.jl:118 [inlined]
[22] (::Diffractor.∂⃖recurse{1})(args::Type{Dict})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
[23] (::∂⃖{1})(::Type{Dict})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[24] my_map
@ .\REPL[20]:2 [inlined]
[25] (::Diffractor.∂⃖recurse{1})(::typeof(my_map), ::var"#14#16", ::Dict{Any, Float32})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
[26] (::∂⃖{1})(::typeof(my_map), ::Function, ::Vararg{Any})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[27] #13
@ .\REPL[24]:1 [inlined]
[28] (::Diffractor.∂⃖recurse{1})(::var"#13#15", ::Dict{Any, Float32})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
[29] (::∂⃖{1})(f::var"#13#15", args::Dict{Any, Float32})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
[30] ∂⃖(::Function, ::Vararg{Any})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\interface.jl:25
[31] (::Diffractor.∇{var"#13#15"})(args::Dict{Any, Float32})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\interface.jl:123
[32] Diffractor.∇(::Function, ::Dict{Any, Float32})
@ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\interface.jl:130
[33] top-level scope
@ REPL[24]:1
Isn't the issue here that you're mutating `new_dict`?
@motabbara I doubt that, unless Diffractor.jl differs radically from Zygote.jl in this respect. The paper "Don't Unroll Adjoint: Differentiating SSA-Form Programs" (https://arxiv.org/abs/1810.07951) explains on page 5 that the immutability requirement applies only to arrays, not to other data structures, so it's no accident that this example works in Zygote.jl.
Of course, the requirements might need to be stricter for Diffractor; @Keno might be able to shed some light on this.
My understanding is that Diffractor uses `rrule`s for the pullbacks of all primitives. The `rrule` that applies to `setindex!` is the catch-all rule that returns `nothing`, and this causes the optic transform to fail.
In contrast, Zygote defines a rule for `setindex!` on dictionaries via its `@adjoint` mechanism, which doesn't seem portable to other AD packages.
@motabbara thanks!
The last time I asked about `Dict` mutation in Diffractor, it was explicitly described as not planned. That may have changed since, but if you need it, it's best to look elsewhere.
@ToucheSir Ah, OK; so when using Diffractor, one is supposed to write this in an immutable style.