Flux.jl
Initialising weights outside of layer declarations
Following the discussions in https://github.com/FluxML/Metalhead.jl/pull/119, I realised that there is currently no way for the user to programmatically pass weight initialisation strategies to layers in a Chain-like structure based on the type of the layer (that is, after the layer has already been declared). This would be quite a useful feature to have, given that many recent models use specific weight initialisations for some types of layers.
An initial idea I had was to add mutating versions of the existing initialisation functions. Then we could have a wrapper function that mutates the weights of an already existing layer, instead of having to copy over an entirely new layer just to change the initial weights. I'm unsure whether this clashes with anything (and I also don't know whether there are already efficient ways to do this with existing functionality), so I'm opening this up for discussion in case there's some conflict before I sit down to write it up.
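To make that concrete, a mutating variant might look roughly like this rough sketch (glorot_uniform! is just a placeholder name here, and the fan computation assumes a Dense-style out × in weight matrix rather than reusing Flux's internal helpers):
using Flux, Random
# Hypothetical mutating counterpart to Flux.glorot_uniform: fill an existing
# weight matrix in place instead of allocating a new one.
function glorot_uniform!(w::AbstractMatrix{<:AbstractFloat})
    fan_out, fan_in = size(w)                 # Dense stores its weight as out × in
    limit = sqrt(6.0f0 / (fan_in + fan_out))  # Glorot/Xavier uniform bound
    rand!(w)                                  # uniform samples in [0, 1)
    w .= (w .* 2 .- 1) .* limit               # rescale to [-limit, limit]
    return w
end
d = Dense(2, 3)
glorot_uniform!(d.weight)   # re-initialise the weight after construction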
cc @darsnack
Mutating variants sound like a good idea. For applying them to initialized models, we could make some wrapper interface like init!(method, model), but that would require defining an init!(..., ::Dense), an init!(..., ::Chain), etc. I was thinking instead that we could fmap the mutating variant over a given model's trainable parameters. The main issue here is that usually one wants to initialize a weight parameter but not the bias.
One way with fmap is:
julia> m = Chain(Dense(2,3));
julia> fmap(m; exclude = x -> hasproperty(x, :weight)) do x
           x.weight .= (1:3)
           x
       end
Chain(
Dense(2, 3), # 9 parameters
)
julia> m.layers[1].weight
3×2 Matrix{Float32}:
1.0 1.0
2.0 2.0
3.0 3.0
julia> m.layers[1].bias
3-element Vector{Float32}:
0.0
0.0
0.0
Is going by field name good enough? It might be. This could be wrapped up as something like weight!(init!, m, fields = :weight), perhaps, in case you want to specify other fields?
It may also not be worth the hassle of making this mutate, since it will only run once. The fmap needed to reconstruct the layers is perhaps slightly more tricky, but it should be doable and could reuse all the existing init functions.
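For example, the non-mutating route might look roughly like this (only a sketch: stop_at and reinit_dense are made-up names, and only Dense is handled, so Conv and friends would need their own methods):
using Flux, Functors
# Walk the model, stop at Dense layers, and rebuild each one with a freshly
# initialised weight of the same size, reusing its existing bias and activation.
stop_at(x) = Functors.isleaf(x) || x isa Dense
reinit_dense(init, l::Dense) = Dense(init(size(l.weight)...), l.bias, l.σ)
reinit_dense(init, x) = x                  # leave anything else untouched
m  = Chain(Dense(2, 3), Dense(3, 1))
m2 = fmap(l -> reinit_dense(Flux.glorot_normal, l), m; exclude = stop_at)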
Yeah my concern was relying on a particular field. We could always make reinit!(method, model) default to fmap on all trainables, then allow custom overrides. Similar approach to #1875.
My fear is that creating a single function for re-init would be too niche for the reasons discussed already (e.g. different parameters in the same layer wanting different init functions). Mutating variants of init functions makes sense to me, however. They'll at least allow users to do things manually until we can think of good higher-level APIs.
I'd like to avoid adding another special function that you have to remember to overload for any new layer so that other people can re-weight it. My sketch above is much too rough, but could there be some nice API a bit like that?
Most layers name their bias field bias, so filtering based on that might not be terrible. Maybe the reweight! function takes field names to ignore (default (:b, :bias)) and field names to act on (default: every other trainable array); then you can target your custom layer?
If it targets trainable, should it live in Optimisers.jl?
Most layers name their bias field bias, so filtering based on that might not be terrible. Maybe the reweight! function takes field names to ignore (default (:b, :bias)) and field names to act on (default: every other trainable array); then you can target your custom layer?
I like this. Let's write it so that the keys to ignore are a nested NamedTuple of the same structure as the model (with nothing for "don't descend this branch"). It's easy enough to go from a plain vector ignores = [:b, :bias] to this nested structure (i.e. fmapstructure(x -> ignores, model)). But the core being nested means we allow branches to have separate, overlapping ignore patterns.
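Roughly like this, I think (just a sketch, and the shape of the result is my guess):
using Flux, Functors
# Expand a flat ignore list into a nested structure mirroring the model, which
# per-branch code could then override.
ignores = [:b, :bias]
m = Chain(Dense(2, 3), Dense(3, 1))
nested = fmapstructure(_ -> ignores, m; exclude = x -> x isa Dense)
# nested should come out shaped like the model, roughly:
# (layers = ([:b, :bias], [:b, :bias]),)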
If it targets trainable, should it live in Optimisers.jl?
My thought is no: Flux will depend on Optimisers, so it can still live here. Initialization is specific to neural network models, not to optimization.
So long as we are only doing mutable models, the easy way to apply this only to some branch is probably something like reweight!(m.layers.enc, glorot_normal; skip = (:b, :bias)). That is, instead of building an API for specifying what branch of the model to act on, just pass that branch in.
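A rough sketch of what I mean (reweight! and skip are hypothetical names; this leans on Flux.modules and Functors.children to reach named arrays, and simply skips non-array fields such as bias = false):
using Flux, Functors
# Visit every functor node in the model (or in a branch of it), re-initialise
# each floating-point array whose field name is not in `skip`, and leave Bool
# biases (bias = false) and activation functions alone.
function reweight!(m, init; skip = (:b, :bias))
    for node in Flux.modules(m)
        for (name, x) in pairs(Functors.children(node))
            name in skip && continue
            x isa AbstractArray{<:AbstractFloat} || continue
            x .= init(size(x)...)   # existing init functions allocate; copy the result in place
        end
    end
    return m
end
m = Chain(Dense(2, 3), Dense(3, 1))
reweight!(m, Flux.glorot_normal)    # or reweight!(m.layers[1], Flux.glorot_normal) for one branch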
True, that's better!
Small sidenote: I would make the initialization method the first argument, to support the do syntax in case anyone needs it.
reweight! would have to return a changed model in order to handle bias=false, or do we not care about those?
The semantic definition of bias=false means that trying to load a numeric value into it is ignored. I think that extends to reweight! too.
Indeed [re argument order]. I guess the next question is what gets passed to that function. Should this work, or should it get the size?
reinit!(model) do x
    x isa AbstractVector && return x
    randn(Float32, size(x))
end
Is what it returns (like here) always copied back into the old array, or only if you do that explicitly? I presume it should return a re-built model à la fmap, but does it guarantee that the old one matches?
Is there a way to get Functors to only "see" down to a certain level? If fmap can somehow be made to stop at the Flux layer level (for custom layers, I reckon that means stopping whenever a struct is found? I'm not sure how Flux recognises those), then instead of passing a skip-list of params we could leave it to the user to define the behaviour for the parameters they want to re-initialise, somewhat like PyTorch, whose behaviour I found quite intuitive here. First define an _init_weights! function that takes care of the necessary behaviour:
function _init_weights!(m)
    if m isa Conv
        m.weight .*= 2
        m.bias .+= 5
    end
    return m
end
Now all that is required is a recursive function (fmap-like, or on the torch side of things, like apply) that can walk through the model and apply this function. I was trying to get this to happen but I couldn't figure out how to get Functors to stop at the Flux layer level - is there a simple way to make this happen?
The exclude kwarg of fmap can be used to stop traversing at any point in the tree. It's set to Functors.isleaf by default, but it's relatively straightforward to write a custom callback:
is_layer_or_leaf(m) = Functors.isleaf(m)
is_layer_or_leaf(::Conv) = true
fmap(_init_weights!, m; exclude=is_layer_or_leaf)
_init_weights! could likewise be written in a dispatch-oriented style.
That's great! I tried something that's a pretty typical use case and it worked quite well:
julia> is_layer_or_leaf(m) = Functors.isleaf(m)
is_layer_or_leaf (generic function with 1 method)
julia> is_layer_or_leaf(::Conv) = true
is_layer_or_leaf (generic function with 2 methods)
julia> is_layer_or_leaf(::Dense) = true
is_layer_or_leaf (generic function with 3 methods)
julia> l = Chain(Dense(3, 3), Conv((3, 3), 3 => 10))
Chain(
Dense(3 => 3), # 12 parameters
Conv((3, 3), 3 => 10), # 280 parameters
) # Total: 4 arrays, 292 parameters, 1.617 KiB.
julia> function _init_weights!(m::Conv)
           m.weight .*= 2
           m.bias .+= 5
           return m
       end
_init_weights! (generic function with 1 method)
julia> function _init_weights!(m::Dense)
           m.weight .*= 3
           m.bias .+= 4
           return m
       end
_init_weights! (generic function with 2 methods)
julia> fmap(_init_weights!, l; exclude = is_layer_or_leaf)
Chain(
Dense(3 => 3), # 12 parameters
Conv((3, 3), 3 => 10), # 280 parameters
) # Total: 4 arrays, 292 parameters, 1.617 KiB.
julia> l[1].bias
3-element Vector{Float32}:
4.0
4.0
4.0
julia> l[2].bias
10-element Vector{Float32}:
5.0
5.0
5.0
5.0
5.0
5.0
5.0
5.0
5.0
5.0
If this approach has no problems, then it seems pretty straightforward to define a reinit function that passes exclude=is_layer_or_leaf to fmap by default. The only problem I can imagine is with layers like LayerNorm, which itself has Flux.Scale as one of its components. Some people may want to consider LayerNorm a leaf and only reinit the explicit Flux.Scale layers, while others may want to reinit all Flux.Scale layers irrespective of whether they're within a LayerNorm or not.
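For what it's worth, both choices seem easy to express with the exclude callback. A sketch (assuming affine = true, so that LayerNorm's diag field is a Flux.Scale):
# A generic fallback so that leaves we don't care about (ϵ, activation
# functions, tuples of sizes, ...) pass through untouched.
_init_weights!(x) = x

# Option 1: treat LayerNorm as a leaf and handle it as a whole, so its inner
# Flux.Scale is never visited separately.
is_layer_or_leaf(::LayerNorm) = true
_init_weights!(m::LayerNorm) = (m.diag.scale .= 1; m.diag.bias .= 0; m)

# Option 2: define the rule on Flux.Scale instead, so that every Scale gets
# re-initialised whether or not it sits inside a LayerNorm.
is_layer_or_leaf(::Flux.Scale) = true
_init_weights!(m::Flux.Scale) = (m.scale .= 1; m.bias .= 0; m)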
The only problem I can imagine is with layers like LayerNorm, which itself has Flux.Scale as one of its components. Some people may want to consider LayerNorm a leaf and only reinit the explicit Flux.Scale layers, while others may want to reinit all Flux.Scale layers irrespective of whether they're within a LayerNorm or not.
This ambiguity is part of why we don't already have a built-in reinit function, IMO. If we had an option for pre-order traversal like https://chengchingwen.github.io/StructWalk.jl/dev/#StructWalk.prewalk-Tuple{Any,%20Any} or https://fluxml.ai/MacroTools.jl/stable/pattern-matching/#Expression-Walking-1, a user could easily choose whether they want to handle LayerNorm.affine separately or not.
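Even without those packages, a small hand-rolled pre-order walk over Functors.children shows the idea (a sketch; prewalk_init! and handle! are made-up names). The handler sees a parent such as LayerNorm before its children, and returning false claims the node so its children are never visited:
using Flux, Functors
# Apply f! to each node before recursing; skip a node's children whenever f!
# returns false.
function prewalk_init!(f!, x)
    f!(x) === false && return x
    foreach(c -> prewalk_init!(f!, c), Functors.children(x))
    return x
end

handle!(m::LayerNorm) = false    # claim LayerNorm wholesale; its Scale is never visited
handle!(m::Flux.Scale) = (m.scale .= 1; m.bias isa AbstractArray && (m.bias .= 0); false)
handle!(m) = true                # anything else: keep walking

prewalk_init!(handle!, Chain(Dense(2, 3), LayerNorm(3)))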
I was trying to give this another go, but I noticed the above example (from here) doesn't work with DenseNet. The error was quite cryptic:
julia> model = DenseNet();
julia> fmap(_init_weights!, model; exclude = is_layer_or_leaf)
ERROR: MethodError: no method matching copyto!(::Bool, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Tuple{}, typeof(+), Tuple{Bool, Int64}})
Closest candidates are:
copyto!(::Zygote.Buffer, ::Any) at ~/.julia/packages/Zygote/DkIUK/src/tools/buffer.jl:54
copyto!(::Any, ::Base.Broadcast.Broadcasted{<:StaticArrays.StaticArrayStyle}) at ~/.julia/packages/StaticArrays/G7IlJ/src/broadcast.jl:68
copyto!(::AbstractArray, ::Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}) at broadcast.jl:929
...
Stacktrace:
[1] broadcasted
@ ./broadcast.jl:1319 [inlined]
[2] broadcasted
@ ./broadcast.jl:1317 [inlined]
[3] _init_weights!(m::Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})
@ Main ./REPL[9]:3
[4] #fmap#17
@ ~/.julia/packages/Functors/qBIlC/src/functor.jl:50 [inlined]
[5] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
[6] iterate
@ ./generator.jl:47 [inlined]
[7] _collect(c::Vector{Any}, itr::Base.Generator{Vector{Any}, Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
@ Base ./array.jl:804
[8] collect_similar
@ ./array.jl:713 [inlined]
[9] map
@ ./abstractarray.jl:2976 [inlined]
[10] _default_walk
@ ~/.julia/packages/Functors/qBIlC/src/functor.jl:43 [inlined]
[11] fmap(f::typeof(_init_weights!), x::Vector{Any}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
[12] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Vector{Any})
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
[13] map
@ ./tuple.jl:273 [inlined]
[14] map(::Function, ::NamedTuple{(:layers,), Tuple{Vector{Any}}})
@ Base ./namedtuple.jl:218
[15] _default_walk(f::Function, x::Chain{Vector{Any}})
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
[16] fmap(f::typeof(_init_weights!), x::Chain{Vector{Any}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
[17] #18
@ ~/.julia/packages/Functors/qBIlC/src/functor.jl:50 [inlined]
[18] map
@ ./tuple.jl:274 [inlined]
[19] _default_walk
@ ~/.julia/packages/Functors/qBIlC/src/functor.jl:43 [inlined]
[20] fmap(f::typeof(_init_weights!), x::Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
[21] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}})
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
[22] map
@ ./tuple.jl:273 [inlined]
[23] map(::Function, ::NamedTuple{(:layers,), Tuple{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}})
@ Base ./namedtuple.jl:218
[24] _default_walk(f::Function, x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}})
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
[25] fmap(f::typeof(_init_weights!), x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
[26] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}})
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
[27] map
@ ./tuple.jl:273 [inlined]
[28] map(::Function, ::NamedTuple{(:layers,), Tuple{Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}}})
@ Base ./namedtuple.jl:218
[29] _default_walk(f::Function, x::DenseNet)
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
[30] fmap(f::typeof(_init_weights!), x::DenseNet; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
@ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
[31] top-level scope
@ REPL[16]:1
[32] top-level scope
@ ~/.julia/packages/CUDA/GGwVa/src/initialization.jl:52
Am I missing something here? Why isn't this working the way it's supposed to?
Most likely you are trying to accumulate into a bias=false field. Bool (if not all non-array) params are probably safe to ignore when re-initializing, but at some point (probably out of scope for now) we'd want to consider immutable arrays like SArray as well. Those would require returning an updated layer from _init_weights! much like Optimisers.update! does now with state + gradients.
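Concretely, a guard in the handler is enough for the mutable case (just a sketch of that fix; returning rebuilt layers for immutable arrays is left aside as noted):
# Same handler as before, but only touch the bias when it is an actual array;
# layers built with bias = false store the Bool `false` there instead.
function _init_weights!(m::Conv)
    m.weight .*= 2
    m.bias isa AbstractArray && (m.bias .+= 5)
    return m
end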