
Type instability in `Flux.setup`

Open · Vilin97 opened this issue on Oct 10, 2023 · 7 comments

using Flux

function test_setup(opt, s)
    state = Flux.setup(opt, s)
    return state
end
s = Chain(
        Dense(2 => 100, softsign),
        Dense(100 => 2)
    )
opt = Adam(0.1)
@code_warntype test_setup(opt, s) # type unstable

Output:

MethodInstance for GradientFlows.test_setup(::Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from test_setup(opt, s) @ GradientFlows c:\Users\Math User\.julia\dev\GradientFlows\src\solvers\sbtm.jl:106
Arguments
  #self#::Core.Const(GradientFlows.test_setup)
  opt::Adam
  s::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  state::Any
Body::Any
1 ─ %1 = Flux.setup::Core.Const(Flux.Train.setup)
│        (state = (%1)(opt, s))
└──      return state

Julia version 1.9.3 and Flux version 0.14.6:

(@v1.9) pkg> st Flux
Status `C:\Users\Math User\.julia\environments\v1.9\Project.toml`
  [587475ba] Flux v0.14.6

Vilin97 avatar Oct 10 '23 18:10 Vilin97
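
(For reference, the same instability can be surfaced with the Test stdlib's @inferred, which throws whenever the actual return type differs from the inferred one. A small sketch, using the same model and optimiser as in the snippet above:)

using Flux, Test

s = Chain(Dense(2 => 100, softsign), Dense(100 => 2))
opt = Adam(0.1)

# Errors, because setup's return type is inferred as Any while the
# actual return value has a concrete NamedTuple type:
@inferred Flux.setup(opt, s)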

setup is defined in Optimisers.jl, and it's inherently type-unstable because it uses a cache to detect and handle shared parameters. Usually I would mark this as a WONTFIX, but there might be some fancy method and/or a newer version of Julia which lets us make setup more type-stable.

ToucheSir avatar Oct 10 '23 18:10 ToucheSir
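
(For context, a small sketch of the shared-parameter case that the cache exists for; the field paths below assume the usual Flux/Optimisers state layout:)

using Flux

# Two layers tied to the same weight array: setup must give them a single
# shared piece of optimiser state, which it detects by walking the model
# with an identity-keyed (===) cache.
shared_w = randn(Float32, 4, 4)
m = Chain(Dense(shared_w), Dense(shared_w))

st = Flux.setup(Adam(0.1), m)
st.layers[1].weight === st.layers[2].weight   # true: one Leaf reused for the tied array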

Values from the cache are used when an object x is === some previously seen x. They should therefore always have the same type as what init(rule, x) returns. If this type can be inferred, we can probably tell the compiler what to expect, and that may make the whole setup call type-stable? Haven't tried, though.

mcabbott avatar Oct 10 '23 19:10 mcabbott

We could use _return_type or friends to do that, yes. One thing I'd like to try to make that easier is to delegate what Functors.CachedWalk currently does to the callback passed into the maps. Then it should be easier to swap in/out different implementations of caching and memoization by simply switching the callback.

ToucheSir avatar Oct 10 '23 23:10 ToucheSir
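
(A purely hypothetical sketch of that idea, not the actual Functors.jl API: caching implemented as a wrapper around the callback rather than inside the walk, so different caching or memoization strategies could be swapped in by switching the callback:)

# memoising is a hypothetical helper: it wraps a callback with an
# identity-keyed cache, so repeated (===) inputs return the cached result.
function memoising(f)
    cache = IdDict{Any,Any}()
    return x -> get!(() -> f(x), cache, x)
end

double = memoising(x -> 2 .* x)
a = [1.0, 2.0]
double(a) === double(a)   # true: the second call hits the cache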

function _setup(rule, x; cache)
  if haskey(cache, x)
    # The cache only ever holds Leaf(rule, init(rule, x)) for a previously seen x,
    # so assert that type on the lookup to help inference:
    T1 = Base._return_type(init, Tuple{typeof(rule), typeof(x)})
    T2 = Base._return_type(Leaf, Tuple{typeof(rule), T1})
    return cache[x]::T2
  end
  if isnumeric(x)
    ℓ = Leaf(rule, init(rule, x))
    # as before...

gives

julia> @code_warntype test_setup(opt, s)
MethodInstance for test_setup(::Optimisers.Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from test_setup(opt, s) @ Main REPL[5]:1
Arguments
  #self#::Core.Const(test_setup)
  opt::Optimisers.Adam
  s::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  state::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
Body::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
1 ─ %1 = Flux.setup::Core.Const(Flux.Train.setup)
│        (state = (%1)(opt, s))
└──      return state

julia> @code_warntype Optimisers.setup(opt, s)
MethodInstance for Optimisers.setup(::Optimisers.Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
  from setup(rule::AbstractRule, model) @ Optimisers ~/.julia/dev/Optimisers/src/interface.jl:29
Arguments
  #self#::Core.Const(Optimisers.setup)
  rule::Optimisers.Adam
  model::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
  tree::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
  cache::IdDict{Any, Any}
  msg::String
  kwargs::@NamedTuple{}
  line::Int64
  file::String
  id::Symbol
  logger::Union{Nothing, Base.CoreLogging.AbstractLogger}
  _module::Module
  group::Symbol
  std_level::Base.CoreLogging.LogLevel
  level::Base.CoreLogging.LogLevel
Body::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
1 ──       (cache = Optimisers.IdDict())
│    %2  = (:cache,)::Core.Const((:cache,))
│    %3  = Core.apply_type(Core.NamedTuple, %2)::Core.Const(NamedTuple{(:cache,)})
...

mcabbott avatar Oct 11 '23 00:10 mcabbott

Looks like the inference path _return_type uses might not be able to work through the recursion? I wonder if we could use a trick like https://github.com/FluxML/Functors.jl/pull/61 to prevent it from bailing.

ToucheSir avatar Oct 11 '23 01:10 ToucheSir

In the meantime, would it make sense to add a sentence like "This function is type-unstable." to the docstring of setup? If I had seen such a sentence in the docstring, it would have saved me the trouble of discovering this for myself.

Vilin97 avatar Oct 18 '23 05:10 Vilin97

would it make sense to add a sentence like "This function is type-unstable." to the docstring of setup?

Yes, probably.

Also to emphasise that the way to deal with this is a function barrier: you run setup exactly once and pass its result to something else. If you are running it in a tight loop, you are probably doing it wrong. (A sketch of this pattern is included below.)

mcabbott avatar Mar 29 '24 16:03 mcabbott
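
(A minimal sketch of the function-barrier pattern described above; the function names and the data format, an iterable of (x, y) batches, are hypothetical:)

using Flux

function train_model!(model, data; epochs = 10)
    opt_state = Flux.setup(Adam(0.1), model)            # type-unstable, but called exactly once
    return train_loop!(opt_state, model, data, epochs)  # function barrier: concrete types from here on
end

function train_loop!(opt_state, model, data, epochs)
    for _ in 1:epochs, (x, y) in data
        grads = Flux.gradient(m -> sum(abs2, m(x) .- y), model)
        Flux.update!(opt_state, model, grads[1])
    end
    return model
end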