Optimisers.jl
Using `adjust!` on weight decay (L2) and sign decay (L1) at the same time?
Motivation and description
In other contexts, combining L1 and L2 regularization can be reasonable. In Optimisers.jl, the two rules use the same parameter name (`lambda`), which, if I understand correctly, means that `adjust!` will change both?
Similarly, what if I wanted to use SignDecay with AdamW and therefore set AdamW's lambda to 0? Would adjusting the SignDecay lambda then make AdamW's lambda non-zero?
Yes, it will change both:
julia> os = Optimisers.setup(OptimiserChain(SignDecay(0.1), AdamW(lambda=0.1)), (; x=[12 34.]))
(x = Leaf(OptimiserChain(SignDecay(0.1), AdamW(0.001, (0.9, 0.999), 0.1, 1.0e-8, true)), (nothing, ([0.0 0.0], [0.0 0.0], (0.9, 0.999)))),)
julia> Optimisers.adjust!(os, lambda=0.3)
julia> os
(x = Leaf(OptimiserChain(SignDecay(0.3), AdamW(0.001, (0.9, 0.999), 0.3, 1.0e-8, true)), (nothing, ([0.0 0.0], [0.0 0.0], (0.9, 0.999)))),)
That's a reason to give them all different names, e.g. lambda1 for SignDecay and lambda2 for WeightDecay.
But that's a breaking change, as people may already be using `adjust!`. Unfortunately, it can't now be combined with the change for AdamW.
How about adding an `adjust!` method that lets you specify the type, like adjust!(state, (SignDecay, lambda = 0.05)), so that it only adjusts optimisers that match the type?
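To make the suggestion concrete, here is a minimal sketch of a type-selective adjustment written as a user-side helper. The names `adjust_only` and `adjust_only!` are hypothetical and are not part of Optimisers.jl. The sketch also assumes that `OptimiserChain` keeps its rules in a field called `opts` and that `Leaf` is mutable with a `rule` field, which are internal details that may change between versions.

```julia
using Optimisers

# Hypothetical helper (not part of Optimisers.jl):
# adjust a rule only if it is of type T, otherwise return it unchanged.
adjust_only(rule, ::Type{T}; kw...) where {T} =
    rule isa T ? Optimisers.adjust(rule; kw...) : rule

# Recurse into chains so that each member is matched individually.
# Assumption: OptimiserChain stores its rules in the field `opts`.
adjust_only(ch::OptimiserChain, ::Type{T}; kw...) where {T} =
    OptimiserChain(map(o -> adjust_only(o, T; kw...), ch.opts)...)

# Walk the state tree and update each Leaf in place.
# Assumption: Leaf is mutable and has a `rule` field.
adjust_only!(leaf::Optimisers.Leaf, ::Type{T}; kw...) where {T} =
    (leaf.rule = adjust_only(leaf.rule, T; kw...); nothing)
adjust_only!(::Nothing, ::Type{T}; kw...) where {T} = nothing
adjust_only!(tree, ::Type{T}; kw...) where {T} =
    foreach(st -> adjust_only!(st, T; kw...), tree)

os = Optimisers.setup(OptimiserChain(SignDecay(0.1), AdamW(lambda=0.0)), (; x=[12 34.]))
adjust_only!(os, SignDecay; lambda=0.05)

os.x.rule.opts[1].lambda  # strength of SignDecay
os.x.rule.opts[2].lambda  # lambda of AdamW, which should be untouched
```

Because the helper lives outside the package, it leaves the interface of `adjust!` unchanged, at the cost of depending on the internal field names noted above.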
All things are possible, but I don't think we should add further complexity to the interface of `adjust!`.
The less-breaking way is to rename the field in all 3 rules, and then overload the implementation of `adjust!` for just these types so that it still accepts lambda. However, if anyone has saved a state tree, that will still be broken.
Note that because of an earlier renaming (maybe from what Flux called things, maybe to match the AdamW convention, when AdamW made a chain with WeightDecay), you can in fact change the L1 and L2 parameters independently, since WeightDecay still accepts the deprecated keyword gamma:
julia> os2 = Optimisers.setup(OptimiserChain(SignDecay(lambda=0.1), WeightDecay(lambda=0.1)), (; x=[12 34.]))
(x = Leaf(OptimiserChain(SignDecay(0.1), WeightDecay(0.1)), (nothing, nothing)),)
julia> Optimisers.adjust!(os2, gamma=0.3)
┌ Warning: The strength of WeightDecay is now field :lambda, not :gamma
│ caller = #111 at rules.jl:800 [inlined]
└ @ Core ~/.julia/packages/Optimisers/lLmiA/src/rules.jl:800
julia> os2
(x = Leaf(OptimiserChain(SignDecay(0.1), WeightDecay(0.3)), (nothing, nothing)),)