Optimisers.jl

`destructure` doesn't work correctly with certain functors

Status: Open — rejuvyesh opened this issue 3 years ago • 3 comments

using Flux
using Functors
using Optimisers

"""
    Custom(abc::Tuple)

Toy layer that stores a tuple of parameter arrays in `abc`; calling it on `x`
computes `x .* abc[1] .+ abc[2]`.

The field type is the type parameter `T<:Tuple` rather than the abstract
`Tuple`, so `abc` has a concrete, inferable type and calls do not fall back to
dynamic dispatch. `Custom(abc)` still constructs it, with `T` inferred.
"""
struct Custom{T<:Tuple}
    abc::T
end
# Register `Custom` with Functors.jl, listing `abc` as the field to traverse.
Functors.@functor(Custom, (abc,))

# Apply the layer: scale `x` elementwise by `abc[1]`, then shift by `abc[2]`.
function (layer::Custom)(x)
    scale, shift = layer.abc[1], layer.abc[2]
    return @. x * scale + shift
end

"""
    Custom(; dim::Integer)

Build a `Custom` layer whose `abc` tuple holds two `Float32` vectors of length
`dim`, each filled with `randn` samples.

`dim` accepts any `Integer`, not only `Int`, so e.g. `Custom(dim = Int32(16))`
also works.
"""
function Custom(; dim::Integer)
    abc = (randn(Float32, dim), randn(Float32, dim))
    return Custom(abc)
end

# Reproducer: a Chain with the custom functor layer between two Dense layers.
model = Flux.Chain(Dense(4, 16), Custom(;dim=16), Dense(16, 4))
# `p` is the flat parameter vector (passed as `flat::Vector{Float32}` in the trace);
# `re` is the `Optimisers.Restructure` object that rebuilds a model from such a vector.
p, re = Optimisers.destructure(model)
# The reported error is raised inside `re(p)` (the `Restructure` call), before the
# rebuilt model is ever applied to the input.
re(p)(randn(Float32, 4, 1))

This leads to:

ERROR: LoadError: type Tuple has no field layers
Stacktrace:
  [1] getproperty
    @ ./Base.jl:42 [inlined]
  [2] functor(#unused#::Type{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}, c::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}})
    @ Flux ~/.julia/packages/Flux/qAdFM/src/layers/basic.jl:44
  [3] _trainable_biwalk(f::Function, x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, aux::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}})
    @ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:94
  [4] #fmap#30
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:78 [inlined]
  [5] _rebuild(x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, off::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}, flat::Vector{Float32}, len::Int64; walk::Function, kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:83
  [6] _rebuild
    @ ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:82 [inlined]
  [7] (::Optimisers.Restructure{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}})(flat::Vector{Float32})
    @ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:51

— rejuvyesh, Mar 23 '22 22:03

Thanks, that's a bug.

It does work on Flux master, BTW, because we changed `Chain` so that it no longer has children that aren't fields. But `destructure` ought to still allow such things.

julia> re(p)(randn(Float32, 4, 1))
4×1 Matrix{Float32}:
 -0.9730846
  4.3140717
  1.6813886
 -1.1657915

julia> Functors.functor(model)
((layers = (Dense(4 => 16), Custom((Float32[1.0382574, 2.225525, ...

julia> @which Functors.functor(typeof(model), model)  # the default
functor(::Type{<:Chain}, x) in Flux at /Users/me/.julia/packages/Functors/qBIlC/src/functor.jl:23

(jl_eaNYpA) pkg> st Flux
Status `/private/var/folders/yq/4p2zwd614y59gszh7y9ypyhh0000gn/T/jl_eaNYpA/Project.toml`
  [587475ba] Flux v0.13.0-DEV `~/.julia/dev/Flux`

— mcabbott, Mar 23 '22 23:03

What were you thinking of as a solution for a quick release? Something like

# NOTE(review): per the trace in the reply, this does not work as-is. The caller appears
# to destructure the result as `(children, reconstruct)`. Returning the bare 3-tuple `c`
# therefore makes its first element (the first Dense layer's offsets, a NamedTuple)
# act as the children. That in turn yields "type Int64 has no field weight".
# Presumably it would need to return a pair such as `c, identity` — unverified.
functor(::Type{<:Chain}, c::Tuple) = c

?

— ToucheSir, Mar 25 '22 03:03

@ToucheSir that specific suggestion leads to:

ERROR: LoadError: type Int64 has no field weight
Stacktrace:
  [1] getproperty
    @ ./Base.jl:42 [inlined]
  [2] functor(#unused#::Type{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, x::Int64)
    @ Flux ~/.julia/packages/Functors/qBIlC/src/functor.jl:23
  [3] _trainable_biwalk(f::Function, x::Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, aux::Int64)
    @ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:94
  [4] #fmap#30
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:78 [inlined]
  [5] #31
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:78 [inlined]
  [6] #22
    @ ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:100 [inlined]
  [7] #4
    @ ./generator.jl:36 [inlined]
  [8] iterate
    @ ./generator.jl:47 [inlined]
  [9] collect(itr::Base.Generator{Base.Iterators.Zip{Tuple{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}}, Base.var"#4#5"{Optimisers.var"#22#23"{Functors.var"#31#32"{typeof(Optimisers.isnumeric), typeof(Optimisers._trainable_biwalk), IdDict{Any, Any}, Functors.NoKeyword, Optimisers.var"#20#21"{Vector{Float32}}}}}})
    @ Base ./array.jl:724
 [10] map(::Function, ::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}, ::NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}})
    @ Base ./abstractarray.jl:2948
 [11] _trainmap
    @ ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:99 [inlined]
 [12] _trainable_biwalk(f::Function, x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, aux::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}})
    @ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:95
 [13] fmap(f::Function, x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, ys::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}; exclude::typeof(Optimisers.isnumeric), walk::Function, cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:78
 [14] _rebuild(x::Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, off::Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}, flat::Vector{Float32}, len::Int64; walk::Function, kw::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:83
 [15] _rebuild
    @ ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:82 [inlined]
 [16] (::Optimisers.Restructure{Chain{Tuple{Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Custom, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}, Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}, NamedTuple{(:abc,), Tuple{Tuple{Int64, Int64}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Int64, Int64, Tuple{}}}}})(flat::Vector{Float32})
    @ Optimisers ~/.julia/packages/Optimisers/UAVzc/src/destructure.jl:51

— rejuvyesh, Mar 25 '22 21:03