Zygote.jl
Higher-order differentiation fails when the inputs are wrapped in a NamedTuple
Minimal working example (MWE):
# Reproduction script: nested (second-order) `gradient` calls, once with a bare
# matrix as the outer argument and once with a NamedTuple.
using Zygote
# Parameters stored in a NamedTuple with a single `weight` field (2×2 Float32).
ps = (weight = randn(Float32, 2, 2),)
# Input: a 2×1 Float32 matrix.
x = randn(Float32, 2, 1)
# The same weight matrix, pulled out of the NamedTuple.
W = ps.weight
# Works: the outer gradient is taken with respect to the bare matrix `W`.
gradient(W) do w
# Inner gradient with respect to `x`; `[1]` takes the first element of what the
# inner `gradient` call returns.
sum(gradient(x) do y
sum(w * y)
end[1])
end
# Fails: the same computation, but the outer gradient is taken with respect to
# the NamedTuple `ps`, and the weight is read as `p.weight` inside the inner closure.
gradient(ps) do p
sum(gradient(x) do y
sum(p.weight * y)
end[1])
end
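Error output (note: frame [26] shows a NamedTuple with both `weight` and `bias` fields, so the trace appears to come from a run where `ps` also had a `bias` entry):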
ERROR: Can't differentiate foreigncall expression
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] Pullback
@ ./essentials.jl:612 [inlined]
[3] (::typeof(∂(getindex)))(Δ::Nothing)
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[4] Pullback
@ /mnt/julia/packages/Zygote/ytjqm/src/tools/builtins.jl:15 [inlined]
[5] (::typeof(∂(literal_getindex)))(Δ::Nothing)
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[6] Pullback
@ ./reflection.jl:792 [inlined]
[7] (::typeof(∂(fieldcount)))(Δ::Nothing)
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[8] Pullback
@ ~/.julia/packages/ChainRulesCore/RbX5a/src/tangent_types/tangent.jl:195 [inlined]
[9] (::typeof(∂(canonicalize)))(Δ::Nothing)
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[10] Pullback
@ /mnt/julia/packages/Zygote/ytjqm/src/compiler/chainrules.jl:115 [inlined]
[11] Pullback
@ /mnt/julia/packages/Zygote/ytjqm/src/compiler/chainrules.jl:183 [inlined]
[12] (::typeof(∂(_project)))(Δ::Nothing)
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[13] Pullback
@ /mnt/julia/packages/Zygote/ytjqm/src/lib/lib.jl:231 [inlined]
[14] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[15] Pullback
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[16] (::typeof(∂(λ)))(Δ::Nothing)
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[17] Pullback
@ ./REPL[16]:3 [inlined]
[18] (::typeof(∂(λ)))(Δ::Tuple{Nothing, FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[19] Pullback
@ /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41 [inlined]
[20] (::typeof(∂(λ)))(Δ::Tuple{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[21] Pullback
@ /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76 [inlined]
[22] (::typeof(∂(gradient)))(Δ::Tuple{FillArrays.Fill{Float32, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}})
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[23] Pullback
@ ./REPL[16]:2 [inlined]
[24] (::typeof(∂(#17)))(Δ::Float32)
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface2.jl:0
[25] (::Zygote.var"#56#57"{typeof(∂(#17))})(Δ::Float32)
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface.jl:41
[26] gradient(f::Function, args::NamedTuple{(:weight, :bias), Tuple{Matrix{Float32}, Matrix{Float32}}})
@ Zygote /mnt/julia/packages/Zygote/ytjqm/src/compiler/interface.jl:76
[27] top-level scope
@ REPL[16]:1
[28] top-level scope
@ /mnt/julia/packages/CUDA/Uurn4/src/initialization.jl:52