Optimisers.jl
Optimisers.jl copied to clipboard
`destructure` doesn't work on dictionaries
`destructure` uses `map`, which I think predates the addition of `Dict` support elsewhere in the package. Since `map` is not defined on dictionaries, this fails:
julia> d = Dict(
:a => Dict(
:b => Dict(
:c => 1,
:d => 2,
),
:e => 3,
),
:f => 4,
)
Dict{Symbol, Any} with 2 entries:
:a => Dict{Symbol, Any}(:b=>Dict(:d=>2, :c=>1), :e=>3)
:f => 4
julia> destructure(d)
ERROR: map is not defined on dictionaries
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:35
[2] map(f::Function, ::Dict{Symbol, Any})
@ Base ./abstractarray.jl:3303
[3] (::Optimisers._TrainableStructWalk)(recurse::Function, x::Dict{Symbol, Any})
@ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/destructure.jl:81
[4] (::Functors.ExcludeWalk{…})(::Function, ::Dict{…})
@ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:106
[5] (::Functors.CachedWalk{…})(::Functors.var"#recurse#19"{…}, ::Dict{…})
@ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:146 [inlined]
[6] execute(::Functors.CachedWalk{Functors.ExcludeWalk{…}, Functors.NoKeyword}, ::Dict{Symbol, Any})
@ Functors ~/.julia/packages/Functors/rlD70/src/walks.jl:38
[7] fmap(::Function, ::Dict{…}; exclude::Function, walk::Optimisers._TrainableStructWalk, cache::IdDict{…}, prune::Functors.NoKeyword)
@ Functors ~/.julia/packages/Functors/rlD70/src/maps.jl:11
[8] _flatten(x::Dict{Symbol, Any})
@ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/destructure.jl:69 [inlined]
[9] destructure(x::Dict{Symbol, Any})
@ Optimisers ~/.julia/packages/Optimisers/F7eR3/src/destructure.jl:30
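With #174, and in particular its use of `mapvalue` instead of `map`, the situation has improved, but it is not fixed yet. `destructure` now succeeds and returns the flattened parameters. However, the reconstructor `re(ps)` returns a `Vector` of `Pair`s rather than the original nested `Dict`: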
julia> d = Dict(
:a => Dict(
:b => Dict(
:c => [1.],
:d => [2.],
),
:e => 3.,
),
:f => [4.],
)
Dict{Symbol, Any} with 2 entries:
:a => Dict{Symbol, Any}(:b=>Dict(:d=>[2.0], :c=>[1.0]), :e=>3.0)
:f => [4.0]
julia> ps, re = destructure(d)
([2.0, 1.0, 4.0], Restructure(Dict, ..., 3))
julia> re(ps)
2-element Vector{Pair{Symbol}}:
:a => Pair{Symbol}[:b => [:d => [2.0], :c => [1.0]], :e => 3.0]
:f => [4.0]