Gradients with respect to dictionaries don't work

anhinga opened this issue 3 years ago · 6 comments

In Julia 1.7.3:

julia> Zygote.gradient(x -> x["foo"]["bar"]^2, Dict("foo" => Dict("bar" => 5)))
(Dict{Any, Any}("foo" => Dict{Any, Any}("bar" => 10)),)

julia> Diffractor.gradient(x -> x["foo"]["bar"]^2, Dict("foo" => Dict("bar" => 5)))
ERROR: ArgumentError: Tangent for the primal Dict{String, Int64} should be backed by a AbstractDict type, not by NamedTuple{(:vals,), Tuple{Vector{Int64}}}.
Stacktrace:
 [1] _backing_error(P::Type, G::Type, E::Type)
   @ ChainRulesCore C:\Users\anhin\.julia\packages\ChainRulesCore\ctmSK\src\tangent_types\tangent.jl:62
 [2] ChainRulesCore.Tangent{Dict{String, Int64}, NamedTuple{(:vals,), Tuple{Vector{Int64}}}}(backing::NamedTuple{(:vals,), Tuple{Vector{Int64}}})
   @ ChainRulesCore C:\Users\anhin\.julia\packages\ChainRulesCore\ctmSK\src\tangent_types\tangent.jl:33
 [3] (::Diffractor.var"#162#164"{Symbol, DataType})(Δ::Vector{Int64})
   @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:309
 [4] (::Diffractor.EvenOddOdd{1, 1, Diffractor.var"#162#164"{Symbol, DataType}, Diffractor.var"#163#165"{Symbol}})(Δ::Vector{Int64})
   @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:289
 [5] ∂⃖¹₁getproperty
   @ .\none:1

and

julia> pars = Dict(:x=>0f0, "y"=>4f0, 8=>-3f0)
Dict{Any, Float32} with 3 entries:
  "y" => 4.0
  8   => -3.0
  :x  => 0.0

julia> function my_map(my_f, my_dict)
                  new_dict = Dict()
                  for k in keys(my_dict)
                      new_dict[k] = my_f(my_dict[k])
                  end
                  new_dict
              end
my_map (generic function with 1 method)

julia> function my_sum(my_dict)
                  s = 0f0
                  for k in keys(my_dict)
                      s += my_dict[k]
                  end
                  s
              end
my_sum (generic function with 1 method)

julia> my_sum(my_map(x->x^2, pars))
25.0f0

julia> Zygote.gradient(pars -> my_sum(my_map(x->x^2, pars)), pars)
(Dict{Any, Any}("y" => 8.0f0, 8 => -6.0f0, :x => 0.0f0),)

julia> Diffractor.gradient(pars -> my_sum(my_map(x->x^2, pars)), pars)
ERROR: MethodError: no method matching (::Diffractor.∂⃖recurse{1})(::typeof(Core.sizeof), ::Int64)
Closest candidates are:
  (::Diffractor.∂⃖recurse)(::Any...) at C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:406
Stacktrace:
  [1] macro expansion
    @ C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0 [inlined]
  [2] (::Diffractor.∂⃖recurse{1})(::typeof(Core.sizeof), ::Int64)
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:415
  [3] (::∂⃖{1})(f::typeof(Core.sizeof), args::Int64)
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
  [4] is_top_bit_set
    @ .\boot.jl:616 [inlined]
  [5] (::∂⃖{1})(f::typeof(Core.is_top_bit_set), args::Int64)
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
  [6] check_top_bit
    @ .\boot.jl:626 [inlined]
  [7] (::Diffractor.∂⃖recurse{1})(::typeof(Core.check_top_bit), ::Type{UInt64}, ::Int64)
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
  [8] (::∂⃖{1})(::typeof(Core.check_top_bit), ::Type, ::Vararg{Any})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
  [9] toUInt64
    @ .\boot.jl:737 [inlined]
 [10] (::Diffractor.∂⃖recurse{1})(::typeof(Core.toUInt64), ::Int64)
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
 [11] (::∂⃖{1})(f::typeof(Core.toUInt64), args::Int64)
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
 [12] UInt64
    @ .\boot.jl:767 [inlined]
 [13] (::Diffractor.∂⃖recurse{1})(::Type{UInt64}, ::Int64)
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
 [14] (::∂⃖{1})(f::Type{UInt64}, args::Int64)
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
 [15] convert
    @ .\number.jl:7 [inlined]
 [16] (::Diffractor.∂⃖recurse{1})(::typeof(convert), ::Type{UInt64}, ::Int64)
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
 [17] (::∂⃖{1})(::typeof(convert), ::Type, ::Vararg{Any})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
 [18] Dict
    @ .\dict.jl:90 [inlined]
 [19] (::Diffractor.∂⃖recurse{1})(args::Type{Dict{Any, Any}})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
 [20] (::∂⃖{1})(::Type{Dict{Any, Any}})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
 [21] Dict
    @ .\dict.jl:118 [inlined]
 [22] (::Diffractor.∂⃖recurse{1})(args::Type{Dict})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
 [23] (::∂⃖{1})(::Type{Dict})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
 [24] my_map
    @ .\REPL[20]:2 [inlined]
 [25] (::Diffractor.∂⃖recurse{1})(::typeof(my_map), ::var"#14#16", ::Dict{Any, Float32})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
 [26] (::∂⃖{1})(::typeof(my_map), ::Function, ::Vararg{Any})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
 [27] #13
    @ .\REPL[24]:1 [inlined]
 [28] (::Diffractor.∂⃖recurse{1})(::var"#13#15", ::Dict{Any, Float32})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:0
 [29] (::∂⃖{1})(f::var"#13#15", args::Dict{Any, Float32})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\stage1\generated.jl:216
 [30] ∂⃖(::Function, ::Vararg{Any})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\interface.jl:25
 [31] (::Diffractor.∇{var"#13#15"})(args::Dict{Any, Float32})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\interface.jl:123
 [32] Diffractor.∇(::Function, ::Dict{Any, Float32})
    @ Diffractor C:\Users\anhin\.julia\packages\Diffractor\WrKGJ\src\interface.jl:130
 [33] top-level scope
    @ REPL[24]:1

anhinga (Sep 01 '22)

Isn't the issue here that you're mutating new_dict?

motabbara (Jan 27 '23)

@motabbara I doubt that, unless Diffractor.jl differs radically from Zygote.jl in this respect (the "Don't Unroll Adjoint: Differentiating SSA-Form Programs" paper, https://arxiv.org/abs/1810.07951, explains on page 5 that the immutability requirement applies only to arrays, not to other data structures, so it's not an accident that this works in Zygote.jl).

Of course, it might be that the requirements need to be stronger for Diffractor; @Keno might be able to shed some light on this...

anhinga (Jan 27 '23)

My understanding is that Diffractor uses rrules for the pullbacks of all primitives. For setindex! the only rule that applies is the catch-all one that returns nothing, which causes the optic transform to fail.

In contrast, Zygote defines a rule for setindex! on dictionaries via its @adjoint mechanism, which doesn't seem portable to other AD packages.

motabbara (Jan 27 '23)
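
For reference, a minimal sketch (not from the thread) of an immutable-style rewrite of my_map and my_sum that avoids setindex! entirely by building the result dictionary with a generator. The function names are made up for illustration, and Diffractor's Dict tangent handling may still reject it, so this only sidesteps the mutation part of the problem:

function my_map_immutable(my_f, my_dict)
    # build a fresh Dict in one shot instead of mutating one via setindex!
    Dict(k => my_f(v) for (k, v) in pairs(my_dict))
end

function my_sum_immutable(my_dict)
    # reduce over the values without an explicit accumulation loop
    sum(values(my_dict); init = 0f0)
end

With the original pars, my_sum_immutable(my_map_immutable(x -> x^2, pars)) still gives 25.0f0; whether either AD package can differentiate this version is a separate question.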

@motabbara thanks!

anhinga (Jan 27 '23)

Last time I asked about Dict mutation in Diffractor, it was explicitly noted as not planned. Maybe that's changed since, but if you need it, it's best to look elsewhere.

ToucheSir (Jan 27 '23)

@ToucheSir Ah, ok; so one is supposed to program this in an immutable style if one is using Diffractor...

anhinga (Jan 27 '23)
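
As a hypothetical follow-up to the immutable-style suggestion (not tested in the thread): one option is to keep the parameters in NamedTuples rather than Dicts, since ChainRulesCore represents NamedTuple and struct tangents structurally, which matches the Tangent representation Diffractor builds. Whether this exact call succeeds depends on the Diffractor version.

using Diffractor

# restructuring of the first example with NamedTuples instead of Dicts
nt = (foo = (bar = 5.0,),)
Diffractor.gradient(x -> x.foo.bar^2, nt)
# if supported, the result should carry a tangent with bar == 10.0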