Optimisers.jl
Type instability in `Flux.setup`
using Flux

function test_setup(opt, s)
    state = Flux.setup(opt, s)
    return state
end

s = Chain(
    Dense(2 => 100, softsign),
    Dense(100 => 2),
)
opt = Adam(0.1)

@code_warntype test_setup(opt, s) # type unstable
Output:
MethodInstance for GradientFlows.test_setup(::Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
from test_setup(opt, s) @ GradientFlows c:\Users\Math User\.julia\dev\GradientFlows\src\solvers\sbtm.jl:106
Arguments
#self#::Core.Const(GradientFlows.test_setup)
opt::Adam
s::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
state::Any
Body::Any
1 ─ %1 = Flux.setup::Core.Const(Flux.Train.setup)
│ (state = (%1)(opt, s))
└── return state
Julia version 1.9.3 and Flux version 0.14.6:
(@v1.9) pkg> st Flux
Status `C:\Users\Math User\.julia\environments\v1.9\Project.toml`
[587475ba] Flux v0.14.6
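As a quick cross-check (a minimal sketch, not part of the original report), `Test.@inferred` surfaces the same problem without reading the lowered code: it throws because the concrete return type of `setup` does not match the `Any` that inference predicts.

using Test, Flux

s = Chain(Dense(2 => 100, softsign), Dense(100 => 2))
opt = Adam(0.1)

# Throws, reporting that the concrete return type of the call does not
# match the inferred return type Any:
@inferred Flux.setup(opt, s)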
`setup` is defined in Optimisers.jl, and it's inherently type-unstable because it uses a cache to detect and handle shared parameters. Usually I would mark this as a WONTFIX, but there might be some fancy method and/or a newer version of Julia that lets us make `setup` more type-stable.
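To make the shared-parameter case concrete, here is a minimal sketch (the model and field names are made up for illustration): when the same array appears twice in a model, the cache hands back the same `Leaf` for both occurrences, so tied parameters share optimiser state.

using Optimisers

w = rand(Float32, 3, 3)                             # one array used in two places
model = (enc = (weight = w,), dec = (weight = w,))  # hypothetical tied-weight model

state = Optimisers.setup(Adam(0.1), model)

# The IdDict cache sees that both `weight` fields are === the same array,
# so it returns the very same Leaf for both, keeping their momenta shared:
state.enc.weight === state.dec.weight  # true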
Values from the cache are used when an object `x` is `===` some previously seen `x`. They should therefore always have the same type as what `init(rule, x)` returns, wrapped in a `Leaf`. If this type can be inferred, we can probably tell the compiler what to expect, and that may make the whole `setup` type-stable? Haven't tried, though.
We could use `_return_type` or friends to do that, yes. One thing I'd like to try, to make that easier, is to delegate what `Functors.CachedWalk` currently does to the callback passed into the maps. Then it should be easier to swap different caching and memoization implementations in and out by simply switching the callback.
function _setup(rule, x; cache)
    if haskey(cache, x)
        # The cached value is always `Leaf(rule, init(rule, x))` for some
        # previously seen `x`, so its type can be computed from types alone
        # and asserted, letting inference see through the cache lookup:
        T1 = Base._return_type(init, Tuple{typeof(rule), typeof(x)})
        T2 = Base._return_type(Leaf, Tuple{typeof(rule), T1})
        return cache[x]::T2
    end
    if isnumeric(x)
        ℓ = Leaf(rule, init(rule, x))
        # as before...
gives
julia> @code_warntype test_setup(opt, s)
MethodInstance for test_setup(::Optimisers.Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
from test_setup(opt, s) @ Main REPL[5]:1
Arguments
#self#::Core.Const(test_setup)
opt::Optimisers.Adam
s::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
state::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
Body::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
1 ─ %1 = Flux.setup::Core.Const(Flux.Train.setup)
│ (state = (%1)(opt, s))
└── return state
julia> @code_warntype Optimisers.setup(opt, s)
MethodInstance for Optimisers.setup(::Optimisers.Adam, ::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}})
from setup(rule::AbstractRule, model) @ Optimisers ~/.julia/dev/Optimisers/src/interface.jl:29
Arguments
#self#::Core.Const(Optimisers.setup)
rule::Optimisers.Adam
model::Chain{Tuple{Dense{typeof(softsign), Matrix{Float32}, Vector{Float32}}, Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}
Locals
tree::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
cache::IdDict{Any, Any}
msg::String
kwargs::@NamedTuple{}
line::Int64
file::String
id::Symbol
logger::Union{Nothing, Base.CoreLogging.AbstractLogger}
_module::Module
group::Symbol
std_level::Base.CoreLogging.LogLevel
level::Base.CoreLogging.LogLevel
Body::NamedTuple{(:layers,), <:Tuple{Tuple{NamedTuple, NamedTuple}}}
1 ── (cache = Optimisers.IdDict())
│ %2 = (:cache,)::Core.Const((:cache,))
│ %3 = Core.apply_type(Core.NamedTuple, %2)::Core.Const(NamedTuple{(:cache,)})
...
Looks like the inference path `_return_type` uses might not be able to work through the recursion? I wonder if we could use a trick like https://github.com/FluxML/Functors.jl/pull/61 to prevent it from bailing.
In the meantime, would it make sense to add a sentence like "This function is type-unstable." to the docstring of `setup`? If I had seen such a sentence in the docstring, it would have saved me the trouble of discovering it for myself.
> would it make sense to add a sentence like "This function is type-unstable." to the docstring of `setup`?
Yes, probably.
Also, to emphasise that the way to deal with this is a function barrier: you run `setup` exactly once and pass its result to something else. If you are running it in a tight loop, you are probably doing it wrong.
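For example (a minimal sketch; `train_steps!` and the toy `data` are made-up names): call `setup` once in outer code, then put the training loop behind a function barrier, so everything inside the loop is compiled for the concrete type of the state.

using Flux

model = Chain(Dense(2 => 100, softsign), Dense(100 => 2))
state = Flux.setup(Adam(0.1), model)  # run exactly once; one dynamic dispatch here is harmless

# Function barrier: inside this function the concrete types of `model` and
# `state` are known, so the hot loop itself compiles to type-stable code.
function train_steps!(model, state, data)
    for (x, y) in data
        grads = Flux.gradient(m -> Flux.Losses.mse(m(x), y), model)
        Flux.update!(state, model, grads[1])
    end
    return model, state
end

train_steps!(model, state, [(rand(Float32, 2, 16), rand(Float32, 2, 16))])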