
How to freeze layers?

reachtarunhere opened this issue · 1 comment

Thanks for the awesome project! I really enjoyed your talk too.

Flux.jl has the option of marking parameters as trainable, as well as deleting things from params. Certain tasks may need some layers frozen for a few epochs and made trainable later, or vice versa.

What is the recommended way in these scenarios?

  1. Freeze layers of a particular type, say Conv (the way we do this in PyTorch is by going through the params and filtering by type)
  2. Freeze a particular layer by index/name
  3. Freeze only part of a layer, say the first w of the weight W in Dense
  4. Freeze a series of layers, say all except the last 2 layers

Thanks!

reachtarunhere · Jul 29 '22

The long-term goal is to offload most of these to Optimisers (See https://github.com/FluxML/Optimisers.jl/pull/49). Thanks for opening an issue regarding this. Currently, there isn't a very convenient way to do this. I will probably write this up in the manual in the upcoming weeks, but here is a short version (without code).

Before anything, we will need a FrozenLayer, which effectively wraps a layer and stores the frozen parameters of that layer in its state. This way, those parameters will not be updated. Now coming to the scenarios:

  1. This one needs the Lux layers to support the fmap API, which is simple, but I just haven't had the time or the necessity to sit down and write it. Once done, it will look like fmap(x -> x isa LayerType ? freeze(x) : x, model) (see the sketch after this list).
  2. This is slightly tricky but possible with the same API. fmap allows you to specify a check for leaf nodes. In this case, a leaf would not just be something with no children; rather, we need to check the names and not descend if there is a name match. This might sound a bit finicky if it is not implemented with a recursive routine.
  3. Same as 1, but instead of freeze(x) write freeze(x, (:weight,)).
  4. Just wrap the entire thing in freeze. So if you have Chain(layer_1, ..., layer_(n-1), layer_n), do Chain(freeze(Chain(layer_1, ..., layer_(n-2))), layer_(n-1), layer_n).
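
To make this a bit more concrete, here is a rough sketch of what scenarios 1, 3 and 4 could look like once the fmap support and a freeze wrapper exist. None of this is current Lux API: freeze is the hypothetical constructor for the FrozenLayer sketched at the end of this comment, and the exclude keyword is how Functors.fmap lets you treat whole layers as leaves.

```julia
using Lux, Functors

# `freeze` is the hypothetical FrozenLayer constructor sketched further below;
# it does not exist in Lux today.

# Assumes 28x28 RGB inputs; the exact sizes are only for illustration.
model = Chain(Conv((3, 3), 3 => 16, relu), Conv((3, 3), 16 => 16, relu),
              FlattenLayer(), Dense(16 * 24 * 24, 10))

# Scenario 1: freeze every layer of a particular type, e.g. all Conv layers.
# The `exclude` check marks Conv layers as leaves, so the mapping function sees
# the whole layer instead of descending into its fields. This also assumes the
# Lux layers support the fmap API as described in point 1.
model_frozen_convs = fmap(l -> l isa Lux.Conv ? freeze(l) : l, model;
                          exclude = l -> l isa Lux.Conv || Functors.isleaf(l))

# Scenario 3: freeze only some named parameters of a layer.
partially_frozen = freeze(Dense(4, 4), (:weight,))

# Scenario 4: freeze everything except the last 2 layers of a Chain.
backbone = Chain(Conv((3, 3), 3 => 16, relu), Conv((3, 3), 16 => 16, relu))
transfer_model = Chain(freeze(backbone), FlattenLayer(), Dense(16 * 24 * 24, 10))
```

Scenario 2 would use the same exclude mechanism, except the check would match on the names encountered during traversal rather than on the layer type, which is the finicky recursive part mentioned above.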

Now coming to why we don't have FrozenLayer right now. As you might have guessed, to initialize this layer we need access to the underlying layer's parameters during state construction. This means we need to do the same initialparameters operation twice, which is not the end of the world, and the v0.4.* implementation will probably just call it twice. Properly fixing this will need a breaking change that merges the initialparameters and initialstates functions into one.
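
For reference, here is a rough, purely illustrative sketch of what such a FrozenLayer could look like on top of the current explicit-parameter interface (AbstractExplicitLayer, initialparameters, initialstates). The names and the state layout are made up, so do not treat this as the eventual implementation:

```julia
using Lux, Random

# Hypothetical wrapper: the frozen parameters live in the state, so they are
# never handed to the optimiser.
struct FrozenLayer{L<:Lux.AbstractExplicitLayer,W} <: Lux.AbstractExplicitLayer
    layer::L
    which::W   # Tuple of parameter names to freeze, or `nothing` to freeze all
end

freeze(layer, which=nothing) = FrozenLayer(layer, which)

# Split a parameter NamedTuple into (frozen, trainable) parts.
_split(ps, which) = (NamedTuple(k => v for (k, v) in pairs(ps) if k in which),
                     NamedTuple(k => v for (k, v) in pairs(ps) if k ∉ which))

function Lux.initialparameters(rng::AbstractRNG, f::FrozenLayer)
    f.which === nothing && return NamedTuple()  # everything frozen => nothing trainable
    return _split(Lux.initialparameters(rng, f.layer), f.which)[2]
end

function Lux.initialstates(rng::AbstractRNG, f::FrozenLayer)
    # Note: this calls initialparameters on the wrapped layer a second time,
    # which is exactly the duplication described above.
    ps = Lux.initialparameters(rng, f.layer)
    frozen = f.which === nothing ? ps : _split(ps, f.which)[1]
    return (frozen=frozen, inner=Lux.initialstates(rng, f.layer))
end

function (f::FrozenLayer)(x, ps, st)
    # Recombine the trainable and frozen parameters before calling the wrapped layer.
    y, inner_st = f.layer(x, merge(ps, st.frozen), st.inner)
    return y, merge(st, (inner=inner_st,))
end
```

The double initialparameters call is visible in initialstates above, which is the duplication this paragraph is about.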

avik-pal · Jul 30 '22