Zygote.jl

Higher-order (nested) differentiation fails when inputs are wrapped in a NamedTuple

Status: open. avik-pal opened this issue 3 years ago · 0 comments

Minimal working example (MWE):

# Minimal reproducer: a nested (second-order) Zygote gradient succeeds when the
# outer differentiated argument is a bare matrix, but fails when the same matrix
# is passed wrapped in a NamedTuple.
using Zygote

# Parameters wrapped in a single-field NamedTuple; `weight` is a random 2×2 Float32 matrix.
ps = (weight = randn(Float32, 2, 2),)
# Input: a random 2×1 Float32 matrix; it is the variable of the inner gradient.
x = randn(Float32, 2, 1)
# The same weight matrix, unwrapped, used by the working comparison case.
W = ps.weight

# Works: the outer gradient is taken with respect to the plain matrix `w`.
# The inner gradient is taken with respect to `y` (bound to `x`); `[1]` selects
# the gradient of that single argument, which is summed so the outer closure
# returns a scalar.
gradient(W) do w
      sum(gradient(x) do  y
            sum(w * y)
      end[1])
end

# Fails: identical computation, except the outer argument is the NamedTuple `p`
# and the matrix is read as `p.weight` inside the inner closure. This raises
# "Can't differentiate foreigncall expression" (see the stacktrace below).
# NOTE(review): the stacktrace frames suggest the error arises while Zygote
# differentiates its own pullback machinery for the NamedTuple
# (`_project` -> ChainRulesCore `canonicalize` -> `fieldcount`); this is
# inferred from the trace, not confirmed.
gradient(ps) do p
       sum(gradient(x) do  y
              sum(p.weight * y)
       end[1])
end

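Stacktrace (note: frame [26] shows `ps` as `NamedTuple{(:weight, :bias), ...}`, so this trace appears to have been captured from a run where `ps` also had a `bias` field; the MWE above keeps only `weight`):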

ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:35
  [2] Pullback
    @ ./essentials.jl:612 [inlined]
  [3] (::typeof(∂(getindex)))(Δ::Nothing)
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [4] Pullback
    @ /mnt/julia/packages/Zygote/ytjqm/src/tools/builtins.jl:15 [inlined]
  [5] (::typeof(∂(literal_getindex)))(Δ::Nothing)
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [6] Pullback
    @ ./reflection.jl:792 [inlined]
  [7] (::typeof(∂(fieldcount)))(Δ::Nothing)
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
  [8] Pullback
    @ ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:195 [inlined]
  [9] (::typeof(∂(canonicalize)))(Δ::Nothing)
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [10] Pullback
    @ /mnt/julia/packages/Zygote/ytjqm/src/compiler/chainrules.jl:115 [inlined]
 [11] Pullback
    @ /mnt/julia/packages/Zygote/ytjqm/src/compiler/chainrules.jl:183 [inlined]
 [12] (::typeof(∂(_project)))(Δ::Nothing)
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [13] Pullback
    @ /mnt/julia/packages/Zygote/ytjqm/src/lib/lib.jl:231 [inlined]
 [14] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [16] (::typeof(∂(λ)))(Δ::Nothing)
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [17] Pullback
    @ ./REPL[16]:3 [inlined]
 [18] (::typeof(∂(λ)))(Δ::Tuple{Nothing, FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [19] Pullback
    @ /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41 [inlined]
 [20] (::typeof(∂(λ)))(Δ::Tuple{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [21] Pullback
    @ /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76 [inlined]
 [22] (::typeof(∂(gradient)))(Δ::Tuple{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [23] Pullback
    @ ./REPL[16]:2 [inlined]
 [24] (::typeof(∂(#17)))(Δ::Float32)
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
 [25] (::Zygote.var"#56#57"{typeof(∂(#17))})(Δ::Float32)
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
 [26] gradient(f::Function, args::NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}})
    @ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
 [27] top-level scope
    @ REPL[16]:1
 [28] top-level scope
    @ /mnt/julia/packages/CUDA/Uurn4/src/initialization.jl:52

Posted by avik-pal on Apr 17, 2022 at 20:04.