Zygote.jl
Zygote.jl copied to clipboard
Error when differentiating `map` over a `NamedTuple`
julia> x = (; a=1, b=2)
(a = 1, b = 2)
julia> map(sqrt, x)
(a = 1.0, b = 1.4142135623730951)
julia> gradient(x -> map(sqrt, x).a, x)
ERROR: MethodError: no method matching lastindex(::Nothing)
Closest candidates are:
lastindex(::Any, ::Any) at abstractarray.jl:348
lastindex(::Union{DataStructures.SortedDict, DataStructures.SortedMultiDict, DataStructures.SortedSet}) at /home/carlo/.julia/packages/DataStructures/nBjdy/src/tokens2.jl:19
lastindex(::Union{ArrayInterface.BidiagonalIndex, ArrayInterface.TridiagonalIndex, ArrayInterface.BandedBlockBandedMatrixIndex, ArrayInterface.BandedMatrixIndex, ArrayInterface.BlockBandedMatrixIndex}) at /home/carlo/.julia/packages/ArrayInterface/61qJ7/src/array_index.jl:208
...
Stacktrace:
[1] last(a::Nothing)
@ Base ./abstractarray.jl:437
[2] (::Zygote.var"#568#574")(::Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, δ::Nothing)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:211
[3] map
@ ./tuple.jl:233 [inlined]
[4] (::Zygote.var"#567#573"{typeof(sqrt), Tuple{Tuple{Int64, Int64}}, Tuple{Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}})(Δ::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/array.jl:211
[5] (::Zygote.var"#2616#back#577"{Zygote.var"#567#573"{typeof(sqrt), Tuple{Tuple{Int64, Int64}}, Tuple{Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}}})(Δ::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[6] (::Zygote.var"#213#214"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2616#back#577"{Zygote.var"#567#573"{typeof(sqrt), Tuple{Tuple{Int64, Int64}}, Tuple{Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}}}})(Δ::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/lib/lib.jl:203
[7] (::Zygote.var"#1754#back#215"{Zygote.var"#213#214"{Tuple{Tuple{Nothing}, Tuple{Nothing}}, Zygote.var"#2616#back#577"{Zygote.var"#567#573"{typeof(sqrt), Tuple{Tuple{Int64, Int64}}, Tuple{Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}, Tuple{Float64, Zygote.ZBack{ChainRules.var"#sqrt_pullback#1146"{Float64, ChainRulesCore.ProjectTo{Float64, NamedTuple{(), Tuple{}}}}}}}}}}})(Δ::Tuple{Float64, Nothing})
@ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
[8] Pullback
@ ./namedtuple.jl:197 [inlined]
[9] (::typeof(∂(map)))(Δ::NamedTuple{(:a, :b), Tuple{Float64, Nothing}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface2.jl:0
[10] Pullback
@ ./REPL[60]:1 [inlined]
[11] (::Zygote.var"#50#51"{typeof(∂(#53))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:41
[12] gradient(f::Function, args::NamedTuple{(:a, :b), Tuple{Int64, Int64}})
@ Zygote ~/.julia/packages/Zygote/ajuwN/src/compiler/interface.jl:76
[13] top-level scope
@ REPL[60]:1
[14] top-level scope
@ ~/.julia/packages/CUDA/9T5Sq/src/initialization.jl:66
This now seems to work; perhaps it should be added as a regression test:
julia> x = (; a=1, b=2)
(a = 1, b = 2)
julia> map(sqrt, x)
(a = 1.0, b = 1.4142135623730951)
julia> gradient(x -> map(sqrt, x).a, x)
((a = 0.5, b = nothing),)
(@v1.9) pkg> st Zygote ChainRules
Status `~/.julia/environments/v1.9/Project.toml`
⌃ [082447d4] ChainRules v1.39.0
[e88e6eb3] Zygote v0.6.41