Lux.jl
How to freeze layers?
Thanks for the awesome project! I really enjoyed your talk too.
Flux.jl has an option for `trainable` as well as for deleting things from `params`. Certain tasks may need some layers frozen for a few epochs and trainable later, or vice versa.
What is the recommended way in these scenarios?
- Freeze layers of a particular type, say `Conv2` (the way we do this in PyTorch is by going through the params and filtering by type)
- Freeze a particular layer by index/name
- Freeze only parts of a layer, say the first `w` of `W` in `Dense`
- Freeze a series of layers, say all except the last 2 layers
Thanks!
The long-term goal is to offload most of these to Optimisers (See https://github.com/FluxML/Optimisers.jl/pull/49). Thanks for opening an issue regarding this. Currently, there isn't a very convenient way to do this. I will probably write this up in the manual in the upcoming weeks, but here is a short version (without code).
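For context, here is a hedged sketch of what that Optimisers-side workflow looks like. The `Optimisers.freeze!`/`Optimisers.thaw!` calls below come from later Optimisers.jl releases (they postdate this issue) and are not something Lux itself exposed at the time:

```julia
using Lux, Optimisers, Random

rng = Random.default_rng()
model = Chain(Dense(4 => 8, relu), Dense(8 => 1))
ps, st = Lux.setup(rng, model)

# Set up the optimiser state, then mark a subtree as frozen so updates skip it.
opt_state = Optimisers.setup(Optimisers.Adam(0.001f0), ps)
Optimisers.freeze!(opt_state.layer_1)   # layer_1 no longer receives updates
# ... compute `grads` with your AD of choice, then:
# opt_state, ps = Optimisers.update(opt_state, ps, grads)
Optimisers.thaw!(opt_state.layer_1)     # make layer_1 trainable again
```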
Before anything, we will need a `FrozenLayer`, which effectively wraps a layer and stores any frozen parameters of that layer in its state. This way, those parameters will not be updated.
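Here is a minimal sketch of that idea, assuming the Lux v0.4-era `AbstractExplicitLayer` interface. The `FrozenLayer` struct and the `freeze` helper below are illustrative, not the actual implementation:

```julia
using Lux, Random

# Sketch of a FrozenLayer wrapper: the wrapped layer's parameters are moved
# into the state, so the optimiser never sees (or updates) them.
struct FrozenLayer{L <: Lux.AbstractExplicitLayer} <: Lux.AbstractExplicitLayer
    layer::L
end

# Expose no trainable parameters.
Lux.initialparameters(::AbstractRNG, ::FrozenLayer) = NamedTuple()

# Stash the wrapped layer's parameters in the state instead. Note the extra
# initialparameters call here, which is the issue discussed further below.
function Lux.initialstates(rng::AbstractRNG, l::FrozenLayer)
    return (frozen_params = Lux.initialparameters(rng, l.layer),
            states = Lux.initialstates(rng, l.layer))
end

# Forward pass: feed the frozen parameters from the state into the wrapped layer.
function (l::FrozenLayer)(x, ps, st::NamedTuple)
    y, st_inner = l.layer(x, st.frozen_params, st.states)
    return y, merge(st, (states = st_inner,))
end

# Convenience constructor used in the scenarios below.
freeze(l::Lux.AbstractExplicitLayer) = FrozenLayer(l)
```

Now, coming to the scenarios: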
- This one needs the Lux layers to support the `fmap` API, which is simple, but I just haven't had the time and necessity to sit down and write it. Once done, it will look like `fmap(x -> x isa LayerType ? freeze(x) : x, model)` (see the sketch after this list).
- This is slightly tricky but possible with the same API. `fmap` allows you to specify a check for leaf nodes. In this case, a leaf would not be something with no children; rather, we need to check the names and not descend if there is a name match. This might sound a bit finicky if it is not implemented with a recursive routine.
- Again, same as 1, but instead of `freeze(x)` write `freeze(x, (:weight,))`.
- Just wrap the entire thing in `freeze`. So if you have `Chain(layer_1, ..., layer_(n-1), layer_n)`, do `Chain(freeze(Chain(layer_1, ..., layer_(n-2))), last_2_layers...)`.
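Assuming the `freeze`/`FrozenLayer` sketch above, and assuming Lux layers participate in `Functors.fmap` (which, as noted, is not wired up yet), the scenarios could look roughly like this. All names here are illustrative:

```julia
using Lux, Functors

# Structure-only example model (never actually run here).
model = Chain(Conv((3, 3), 3 => 16, relu), Conv((3, 3), 16 => 32, relu), Dense(32 => 10))

# Scenario 1: freeze every layer of a given type. The `exclude` check stops
# fmap from descending into matched layers before wrapping them.
frozen_convs = fmap(x -> x isa Conv ? freeze(x) : x, model; exclude = x -> x isa Conv)

# Scenario 3: freeze only the weight (hypothetical two-argument form of freeze).
partially_frozen = freeze(Dense(32 => 10), (:weight,))

# Scenario 4: freeze everything except the last two layers.
backbone = Chain(Dense(4 => 8, relu), Dense(8 => 8, relu), Dense(8 => 8, relu))
head_1, head_2 = Dense(8 => 4, relu), Dense(4 => 2)
model2 = Chain(freeze(backbone), head_1, head_2)
```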
Now, coming to why we don't have `FrozenLayer` right now. As you might have guessed, initializing this layer requires access to the underlying layer's parameters during state construction. This means we need to do the same `initialparameters` operation twice, which is not the end of the world, and the v0.4.* implementation will probably just call it twice. Fixing this properly will need a breaking change to merge the `initialparameters` and `initialstates` functions into one.
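A hedged sketch of what such a merged initializer could look like for the wrapper above; the name `initial_params_and_states` and its signature are purely hypothetical:

```julia
# Hypothetical merged initializer: returning both trees at once lets the
# wrapper build its state without running the parameter initializer twice.
function initial_params_and_states(rng::AbstractRNG, l::FrozenLayer)
    ps_inner, st_inner = initial_params_and_states(rng, l.layer)
    # Expose no trainable parameters; stash the inner parameters in the state.
    return NamedTuple(), (frozen_params = ps_inner, states = st_inner)
end
```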