
Initialising weights outside of layer declarations

Open theabhirath opened this issue 3 years ago • 18 comments

Following the discussions in https://github.com/FluxML/Metalhead.jl/pull/119, I realised that there is currently no way for the user to programmatically pass in weight initialisation strategies for layers in a Chain-like structure based on the type of the layer (that is, after the layer has already been declared). This would be a very useful feature to have, given that many recent models use specific weight initialisations for certain types of layers.

An initial idea I had was to add mutating versions of the existing initialisation functions. We could then have a wrapper function that mutates the weights of an already existing layer, instead of having to copy over an entirely new layer just to change the initial weights. I'm not sure whether this clashes with anything (and I also don't know whether there are already efficient ways to do this with existing functionality), so I'm opening this up for discussion in case there's some conflict before I sit down to write it up.
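
For concreteness, a rough sketch of what a mutating variant of an existing initialiser could look like (the name glorot_uniform! is hypothetical and not part of Flux; it just fills an existing array in place):

using Flux

# Hypothetical in-place counterpart to Flux.glorot_uniform: overwrite an
# existing array rather than allocating a new one.
function glorot_uniform!(w::AbstractArray)
    w .= Flux.glorot_uniform(size(w)...)
    return w
end

# e.g. glorot_uniform!(layer.weight) for an already-constructed layer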

cc @darsnack

theabhirath avatar Feb 19 '22 01:02 theabhirath

Mutating variants sound like a good idea. For applying them to initialized models, we could make some wrapper interface like init!(method, model), but that would require defining an init!(..., ::Dense), an init!(..., ::Chain), etc. I was thinking that instead we could fmap the mutating variant over a given model's trainable parameters. The main issue here is that usually one wants to initialize a weight parameter but not the bias.

darsnack avatar Feb 19 '22 02:02 darsnack

One way with fmap is:

julia> m = Chain(Dense(2,3));

julia> fmap(m; exclude = x -> hasproperty(x, :weight)) do x
         x.weight .= (1:3)
         x
       end
Chain(
  Dense(2, 3),                          # 9 parameters
)

julia> m.layers[1].weight
3×2 Matrix{Float32}:
 1.0  1.0
 2.0  2.0
 3.0  3.0

julia> m.layers[1].bias
3-element Vector{Float32}:
 0.0
 0.0
 0.0

Is going by field name good enough? It might be. This could be wrapped up as something like weight!(init!, m, fields = :weight), perhaps, in case you want to specify other fields?

It may also not be worth the hassle of making this mutate, since it will only run once. The fmap needed to reconstruct the layers is maybe slightly more tricky, but it should be doable, and could reuse all the existing init functions.

mcabbott avatar Feb 19 '22 02:02 mcabbott

Yeah my concern was relying on a particular field. We could always make reinit!(method, model) default to fmap on all trainables, then allow custom overrides. Similar approach to #1875.

darsnack avatar Feb 19 '22 14:02 darsnack

My fear is that creating a single function for re-init would be too niche, for the reasons discussed already (e.g. different parameters in the same layer wanting different init functions). Mutating variants of the init functions make sense to me, however. They'll at least allow users to do things manually until we can think of good higher-level APIs.

ToucheSir avatar Feb 19 '22 17:02 ToucheSir

I'd like to avoid adding another special function you have to remember to overload for any new layer, so that other people can re-weight it. My sketch above is much too rough, but can there be some nice API a bit like that?

Most layers call the bias field bias, so filtering based on that might not be terrible. Maybe the reweight! function takes field names to ignore (default (:b, :bias)) and field names to act on (default: every other trainable array); then you can target your custom layer?

If it targets trainable, should it live in Optimisers.jl?
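
A rough sketch of what such a reweight! might look like, going by the weight field name as above (everything here, including the name and the defaults, is illustrative and not an actual Flux or Optimisers API):

using Flux, Functors

# Treat anything that carries a weight field (or is a bare leaf) as a unit.
is_weighted(x) = hasproperty(x, :weight) || Functors.isleaf(x)

# Re-initialise every numeric array field of such units, except the fields
# named in `skip` (by default the usual bias names).
function reweight!(model, init; skip = (:b, :bias))
    fmap(model; exclude = is_weighted) do layer
        layer isa AbstractArray && return layer    # bare leaf: leave it alone
        for name in propertynames(layer)
            name in skip && continue
            val = getproperty(layer, name)
            val isa AbstractArray{<:Number} && (val .= init(size(val)...))
        end
        return layer
    end
    return model
end

# e.g. reweight!(model, Flux.glorot_normal)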

mcabbott avatar Feb 19 '22 18:02 mcabbott

Most layers call the bias field bias, so filtering based on that might not be terrible. Maybe the reweight! function takes field names to ignore (default (:b, :bias)) and field names to act on (default: every other trainable array); then you can target your custom layer?

I like this. Let's write it so that the keys to ignore are a nested NamedTuple of the same structure as the model (with nothing for "don't descend this branch"). It's easy enough to go from a plain vector ignores = [:b, :bias] to this nested structure (i.e. fmapstructure(x -> ignores, model)). But the core being nested means we allow branches to have separate, overlapping ignore patterns.
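
Just to illustrate the fmapstructure step (a small sketch; the variable names are only for illustration):

using Flux, Functors

ignores = [:b, :bias]
m = Chain(Dense(2, 3), Dense(3, 1))

# Mirror the model's structure, placing `ignores` at every leaf position;
# individual branches of the result could then be edited to carry separate,
# overlapping ignore patterns.
nested = fmapstructure(_ -> ignores, m)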

If it targets trainable, should it live in Optimisers.jl?

My thought is no: Flux will depend on Optimisers, so it can still live here. Initialization is specific to neural network models and not to optimization.

darsnack avatar Feb 19 '22 19:02 darsnack

So long as we are only doing mutable models, the easy way to apply this only to some branch is probably something like reweight!(m.layers.enc, glorot_normal; skip = (:b, :bias)). That is, instead of building an API for specifying what branch of the model to act on, just pass that branch in.

mcabbott avatar Feb 19 '22 19:02 mcabbott

True, that's better!

darsnack avatar Feb 19 '22 19:02 darsnack

Small sidenote: I would make the initialization method the first arg, to support the do syntax in case anyone needs it.

darsnack avatar Feb 19 '22 19:02 darsnack

reweight! would have to return a changed model in order to handle bias=false, or do we not care about those?

ToucheSir avatar Feb 19 '22 19:02 ToucheSir

The semantics of bias=false mean that trying to load a numeric value into it is ignored. I think that extends to reweight! too.

darsnack avatar Feb 19 '22 19:02 darsnack

Indeed [re argument order]. I guess the next question is what gets passed to that function. Should this work, or should it get the size?

reinit!(model) do x
  x isa AbstractVector && return x
  randn(Float32, size(x))
end

Is what it returns (like here) always copied back into the old array, or only if you do so explicitly? I presume it should return a re-built model à la fmap, but does it guarantee that the old one matches?

mcabbott avatar Feb 19 '22 19:02 mcabbott

Is there a way to get Functors to only "see" down to a certain level? If fmap can somehow be overloaded to stop at the Flux layer level (for custom layers, I reckon that means stopping when a struct is found? Not sure how Flux recognises those), then instead of passing a skip-list for params, we could just leave it to the user to define the behaviour for the parameters they want to re-initialise (somewhat like PyTorch, whose behaviour I found quite intuitive in this case). First, define an _init_weights! function that takes care of the necessary behaviour:

function _init_weights!(m)
    if m isa Conv
        m.weight .*= 2
        m.bias .+= 5
    end
    return m
end

Now all that is required is a recursive function (fmap-like, or on the torch side of things, like apply) that can walk through the model and apply this function. I was trying to get this to happen but I couldn't figure out how to get Functors to stop at the Flux layer level - is there a simple way to make this happen?

theabhirath avatar Apr 23 '22 05:04 theabhirath

The exclude kwarg of fmap can be used to stop traversing at any point in the tree. It's set to Functors.isleaf by default, but it's relatively straightforward to write a custom callback:

is_layer_or_leaf(m) = Functors.isleaf(m)
is_layer_or_leaf(::Conv) = true

fmap(_init_weights!, m; exclude=is_layer_or_leaf)

_init_weights! could likewise be written in a dispatch-oriented style.

ToucheSir avatar Apr 23 '22 22:04 ToucheSir

That's great! I tried something that's a pretty typical use case and it worked quite well:

julia> is_layer_or_leaf(m) = Functors.isleaf(m)
is_layer_or_leaf (generic function with 1 method)

julia> is_layer_or_leaf(::Conv) = true
is_layer_or_leaf (generic function with 2 methods)

julia> is_layer_or_leaf(::Dense) = true
is_layer_or_leaf (generic function with 3 methods)

julia> l = Chain(Dense(3, 3), Conv((3, 3), 3 => 10))
Chain(
  Dense(3 => 3),                        # 12 parameters
  Conv((3, 3), 3 => 10),                # 280 parameters
)                   # Total: 4 arrays, 292 parameters, 1.617 KiB.

julia> function _init_weights!(m::Conv)
           m.weight .*= 2
           m.bias .+= 5
           return m
       end
_init_weights! (generic function with 1 method)

julia> function _init_weights!(m::Dense)
           m.weight .*= 3
           m.bias .+= 4
           return m
       end
_init_weights! (generic function with 2 methods)

julia> fmap(_init_weights!, l; exclude = is_layer_or_leaf)
Chain(
  Dense(3 => 3),                        # 12 parameters
  Conv((3, 3), 3 => 10),                # 280 parameters
)                   # Total: 4 arrays, 292 parameters, 1.617 KiB.

julia> l[1].bias
3-element Vector{Float32}:
 4.0
 4.0
 4.0

julia> l[2].bias
10-element Vector{Float32}:
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0
 5.0

If this approach has no problems, then it seems pretty straightforward to define a reinit function that passes exclude = is_layer_or_leaf to fmap by default. The only problem I can imagine is with layers like LayerNorm, which itself has Flux.Scale as one of its components. Some people may want to consider LayerNorm a leaf and only reinit the explicit Flux.Scale layers, while others may want to reinit all Flux.Scale layers irrespective of whether they're within a LayerNorm or not.
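
To make the two options concrete, building on the definitions above (an illustrative sketch, not a recommendation either way):

# (a) Treat LayerNorm as a unit, so its inner Scale is never visited separately:
is_layer_or_leaf(::LayerNorm) = true
_init_weights!(m::LayerNorm) = m    # or re-initialise its inner Scale here explicitly

# (b) Stop at Scale instead, which also catches the Scale inside every LayerNorm:
is_layer_or_leaf(::Flux.Scale) = true
function _init_weights!(m::Flux.Scale)
    m.scale .= 1
    m.bias isa AbstractArray && (m.bias .= 0)
    return m
end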

theabhirath avatar Apr 24 '22 01:04 theabhirath

The only problem I can imagine is with layers like LayerNorm, which itself has Flux.Scale as one of its components. Some people may want to consider LayerNorm a leaf and only reinit the explicit Flux.Scale layers, while others may want to reinit all Flux.Scale layers irrespective of whether they're within a LayerNorm or not.

This ambiguity is part of why we don't already have a built-in reinit function, IMO. If we had an option for pre-order traversal like https://chengchingwen.github.io/StructWalk.jl/dev/#StructWalk.prewalk-Tuple{Any,%20Any} or https://fluxml.ai/MacroTools.jl/stable/pattern-matching/#Expression-Walking-1, a user could easily choose whether they want to handle LayerNorm.affine separately or not.

ToucheSir avatar Apr 24 '22 03:04 ToucheSir

I was trying to give this another go, but I noticed the above example (from here) doesn't work with DenseNet. The error was quite cryptic:

julia> model = DenseNet();

julia> fmap(_init_weights!, model; exclude = is_layer_or_leaf)
ERROR: MethodError: no method matching copyto!(::Bool, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{0}, Tuple{}, typeof(+), Tuple{Bool, Int64}})
Closest candidates are:
  copyto!(::Zygote.Buffer, ::Any) at ~/.julia/packages/Zygote/DkIUK/src/tools/buffer.jl:54
  copyto!(::Any, ::Base.Broadcast.Broadcasted{<:StaticArrays.StaticArrayStyle}) at ~/.julia/packages/StaticArrays/G7IlJ/src/broadcast.jl:68
  copyto!(::AbstractArray, ::Base.Broadcast.Broadcasted{<:Base.Broadcast.AbstractArrayStyle{0}}) at broadcast.jl:929
  ...
Stacktrace:
  [1] broadcasted
    @ ./broadcast.jl:1319 [inlined]
  [2] broadcasted
    @ ./broadcast.jl:1317 [inlined]
  [3] _init_weights!(m::Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})
    @ Main ./REPL[9]:3
  [4] #fmap#17
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:50 [inlined]
  [5] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Conv{2, 2, typeof(identity), Array{Float32, 4}, Bool})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
  [6] iterate
    @ ./generator.jl:47 [inlined]
  [7] _collect(c::Vector{Any}, itr::Base.Generator{Vector{Any}, Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)}}, #unused#::Base.EltypeUnknown, isz::Base.HasShape{1})
    @ Base ./array.jl:804
  [8] collect_similar
    @ ./array.jl:713 [inlined]
  [9] map
    @ ./abstractarray.jl:2976 [inlined]
 [10] _default_walk
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:43 [inlined]
 [11] fmap(f::typeof(_init_weights!), x::Vector{Any}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [12] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Vector{Any})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [13] map
    @ ./tuple.jl:273 [inlined]
 [14] map(::Function, ::NamedTuple{(:layers,), Tuple{Vector{Any}}})
    @ Base ./namedtuple.jl:218
 [15] _default_walk(f::Function, x::Chain{Vector{Any}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
 [16] fmap(f::typeof(_init_weights!), x::Chain{Vector{Any}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [17] #18
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:50 [inlined]
 [18] map
    @ ./tuple.jl:274 [inlined]
 [19] _default_walk
    @ ~/.julia/packages/Functors/qBIlC/src/functor.jl:43 [inlined]
 [20] fmap(f::typeof(_init_weights!), x::Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [21] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [22] map
    @ ./tuple.jl:273 [inlined]
 [23] map(::Function, ::NamedTuple{(:layers,), Tuple{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}})
    @ Base ./namedtuple.jl:218
 [24] _default_walk(f::Function, x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
 [25] fmap(f::typeof(_init_weights!), x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [26] (::Functors.var"#18#19"{typeof(is_layer_or_leaf), typeof(Functors._default_walk), IdDict{Any, Any}, Functors.NoKeyword, typeof(_init_weights!)})(x::Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}})
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [27] map
    @ ./tuple.jl:273 [inlined]
 [28] map(::Function, ::NamedTuple{(:layers,), Tuple{Chain{Tuple{Chain{Vector{Any}}, Chain{Tuple{AdaptiveMeanPool{4, 2}, typeof(MLUtils.flatten), Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}}})
    @ Base ./namedtuple.jl:218
 [29] _default_walk(f::Function, x::DenseNet)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:43
 [30] fmap(f::typeof(_init_weights!), x::DenseNet; exclude::typeof(is_layer_or_leaf), walk::typeof(Functors._default_walk), cache::IdDict{Any, Any}, prune::Functors.NoKeyword)
    @ Functors ~/.julia/packages/Functors/qBIlC/src/functor.jl:50
 [31] top-level scope
    @ REPL[16]:1
 [32] top-level scope
    @ ~/.julia/packages/CUDA/GGwVa/src/initialization.jl:52

Am I missing something here? Why isn't this working the way it's supposed to?

theabhirath avatar Jun 05 '22 16:06 theabhirath

Most likely you are trying to accumulate into a bias=false field. Bool (if not all non-array) params are probably safe to ignore when re-initializing, but at some point (probably out of scope for now) we'd want to consider immutable arrays like SArray as well. Those would require returning an updated layer from _init_weights! much like Optimisers.update! does now with state + gradients.
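
For example, a guarded version of the earlier _init_weights! that only touches the bias when it is actually an array (the failing Conv in the stacktrace has a Bool bias, i.e. it was built with bias = false):

function _init_weights!(m::Conv)
    m.weight .*= 2
    # Skip layers built with bias = false, where m.bias is the Bool `false`.
    m.bias isa AbstractArray && (m.bias .+= 5)
    return m
end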

ToucheSir avatar Jun 05 '22 17:06 ToucheSir