optax
Add weight_decay and mask arguments to adabelief optimizer.
Fixes #1290.
Could we move towards naming the weight decay mask argument weight_decay_mask instead of mask? We should probably adopt this convention, since mask is somewhat ambiguous in a high-level optimizer interface.
If adabelief needs weight decay then perhaps we can think of adding it to all the main optimizers in alias.py? Wdyt? @carlosgmartin
@rdyro I've changed the argument's name from mask to weight_decay_mask.
I'll leave renaming the other optimizers' arguments to a subsequent PR, to keep this one self-contained.
@rdyro Does this look good?
I'm not entirely sure about this change: the original AdaBelief paper explicitly discusses weight decay but does not use it.
The problem for optax is that weight decay is NOT scaled by the learning rate, so the user has two options for adding weight decay to an existing optimizer:
- reimplement the optimizer chain to insert the weight decay before scale_by_learning_rate in the chain
- chain the pre-made optimizers (e.g., adabelief) with another chain of (weight_decay, scale_by_learning_rate)
It'd be great if we can solve this problem more systematically to not have to add extra weight decay arguments to every popular optimizer.
Perhaps we can introduce another keyword argument to add_decayed_weights that takes the learning rate (schedule)? @carlosgmartin
For a systematic fix, I'd prefer to remove the additional weight_decay keyword argument from pre-made optimizers, keeping it only on the ones that explicitly include weight decay (e.g., adamw) and, for backward compatibility, on the ones to which we have already added the weight_decay kwarg.
What does @vroulet think?
The original author's repository seems to include weight decay: https://github.com/juntang-zhuang/Adabelief-Optimizer/tree/update_0.2.0. So having a weight decay implementation makes sense.
I agree with Robert that the current duplication of weight_decay arguments is pretty bad (in particular, the documentation is quite heavy; it would be best to have a "see_also" so people know how to add weight decay). I like the idea of maybe adding a keyword argument to add_decayed_weights (it may lead to a relatively large refactoring though).