Optimisers.jl
`destructure` doesn't work correctly with certain functors
using Flux
using Functors
using Optimisers
# Layer holding a tuple of arrays; the call method below uses `abc[1]` as an
# element-wise scale and `abc[2]` as an element-wise shift.
# The tuple type is a type parameter so the field is concretely typed
# (an abstract `::Tuple` field forces boxing and type-unstable access).
struct Custom{T<:Tuple}
    abc::T
end
Functors.@functor(Custom, (abc,))  # expose only `abc` as a child for fmap/destructure
# Forward pass: broadcast `x * abc[1] + abc[2]` element-wise in a single fused loop.
function (layer::Custom)(x)
    scale, shift = layer.abc
    return @. x * scale + shift
end
"""
    Custom(; dim::Integer)

Construct a `Custom` layer whose `abc` tuple holds two `Float32` vectors of
length `dim`, each filled with `randn` samples.
"""
function Custom(; dim::Integer)
    # `randn` accepts any Integer size, so don't restrict `dim` to `Int`.
    abc = (randn(Float32, dim), randn(Float32, dim))
    return Custom(abc)
end
# Place the Custom functor between two Dense layers, flatten the model into a
# parameter vector `p` plus a rebuilder `re`, then rebuild and run it on one input.
model = Flux.Chain(
    Dense(4, 16),
    Custom(dim = 16),
    Dense(16, 4),
)
p, re = Optimisers.destructure(model)
re(p)(randn(Float32, 4, 1))
Running this fails with:
ERROR: LoadError: type Tuple has no field layers
Stacktrace:
[1] getproperty
@ ./Base.jl:42 [inlined]
[2] functor(#unused#::Type{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, c::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}})
@ Flux ~/.julia/packages/Flux/qAdFM/src/layers/basic.jl:44
[3] _trainable_biwalk(f::Function, x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, aux::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}})
@ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:94
[4] #fmap#30
@ ~/.julia/packages/Functors/qBIlC/src/functor.jl:78 [inlined]
[5] _rebuild(x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, off::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}, flat::Vector{Float32}, len::Int64; walk::Function, kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:83
[6] _rebuild
@ ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:82 [inlined]
[7] (::Optimisers.Restructure{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}})(flat::Vector{Float32})
@ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:51
Thanks, that's a bug.
It does work on Flux master, by the way, because we changed `Chain` so that it no longer has children that aren't fields. But `destructure` should still support functors like that.
julia> re(p)(randn(Float32, 4, 1))
4×1 Matrix{Float32}:
-0.9730846
4.3140717
1.6813886
-1.1657915
julia> Functors.functor(model)
((layers = (Dense(4 => 16), Custom((Float32[1.0382574, 2.225525, ...
julia> @which Functors.functor(typeof(model), model) # the default
functor(::Type{<:Chain}, x) in Flux at /Users/me/.julia/packages/Functors/qBIlC/src/functor.jl:23
(jl_eaNYpA) pkg> st Flux
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_eaNYpA/Project.toml`
[587475ba] Flux v0.13.0-DEV `~/.julia/dev/Flux`
What did you have in mind as a fix for a quick release? Something like
# Proposed quick fix: the trace above shows `functor` being called with a Chain
# *type* but a plain Tuple of per-layer offsets (from `_trainable_biwalk` during
# `re(p)`), and the default Chain method then tries to read `.layers` from it.
# This method returns such a Tuple unchanged instead.
# NOTE(review): `functor` presumably has to return a `(children, reconstruct)`
# pair. With bare `c`, the caller appears to destructure the tuple itself and take
# the first layer's offsets NamedTuple as the children (see the `map(...)` frame
# below, whose third argument is a single NamedTuple{(:weight, :bias, :σ)}). That
# would explain the "type Int64 has no field weight" failure. `(c, identity)` was
# likely intended; also, outside Flux this should extend `Functors.functor` rather
# than define a new `functor`. Confirm.
functor(::Type{<:Chain}, c::Tuple) = c
?
@ToucheSir, that specific suggestion leads to:
ERROR: LoadError: type Int64 has no field weight
Stacktrace:
[1] getproperty
@ ./Base.jl:42 [inlined]
[2] functor(#unused#::Type{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, x::Int64)
@ Flux ~/.julia/packages/Functors/qBIlC/src/functor.jl:23
[3] _trainable_biwalk(f::Function, x::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, aux::Int64)
@ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:94
[4] #fmap#30
@ ~/.julia/packages/Functors/qBIlC/src/functor.jl:78 [inlined]
[5] #31
@ ~/.julia/packages/Functors/qBIlC/src/functor.jl:78 [inlined]
[6] #22
@ ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:100 [inlined]
[7] #4
@ ./generator.jl:36 [inlined]
[8] iterate
@ ./generator.jl:47 [inlined]
[9] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}, Base.var"#4#5"{Optimisers.var"#22#23"{Functors.var"#31#32"{typeof(Optimisers.isnumeric), typeof(Optimisers._trainable_biwalk), IdDict{Any, Any}, Functors.NoKeyword, Optimisers.var"#20#21"{Vector{Float32}}}}}})
@ Base ./array.jl:724
[10] map(::Function, ::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}})
@ Base ./abstractarray.jl:2948
[11] _trainmap
@ ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:99 [inlined]
[12] _trainable_biwalk(f::Function, x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, aux::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}})
@ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:95
[13] fmap(f::Function, x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, ys::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}; exclude::typeof(Optimisers.isnumeric), walk::Function, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:78
[14] _rebuild(x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, off::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}, flat::Vector{Float32}, len::Int64; walk::Function, kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
@ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:83
[15] _rebuild
@ ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:82 [inlined]
[16] (::Optimisers.Restructure{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}})(flat::Vector{Float32})
@ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:51