
Add weight_decay and mask arguments to adabelief optimizer.

Open carlosgmartin opened this issue 7 months ago • 6 comments

Fixes #1290.
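A rough sketch of the intended usage, assuming the interface proposed in this PR (placeholder values; the mask argument is renamed weight_decay_mask later in the review):

```python
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}

# Proposed arguments (not yet in a released optax): a decoupled weight decay
# coefficient plus a pytree mask selecting which leaves get decayed.
opt = optax.adabelief(
    learning_rate=1e-3,
    weight_decay=1e-4,
    weight_decay_mask={"w": True, "b": False},
)
state = opt.init(params)
```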

carlosgmartin avatar May 01 '25 22:05 carlosgmartin

Could we move towards calling the weight decay mask weight_decay_mask instead of mask? We should probably move to this convention, since mask is somewhat ambiguous in a high-level optimizer interface.

If adabelief needs weight decay then perhaps we can think of adding it to all the main optimizers in alias.py? Wdyt? @carlosgmartin
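For reference, this is the kind of mask in question, shown with adamw (which already exposes it under the name mask); a minimal sketch with placeholder shapes:

```python
import jax
import jax.numpy as jnp
import optax

params = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)}}

# Boolean pytree selecting which leaves receive weight decay:
# decay kernels, skip biases.
decay_mask = jax.tree_util.tree_map(lambda _: True, params)
decay_mask["dense"]["bias"] = False

opt = optax.adamw(learning_rate=1e-3, weight_decay=1e-4, mask=decay_mask)
state = opt.init(params)
```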

rdyro avatar May 05 '25 16:05 rdyro

@rdyro I've changed the argument's name from mask to weight_decay_mask.

I'll leave changing the other optimizers' argument names to a subsequent PR, to keep this one self-contained.

carlosgmartin avatar May 06 '25 19:05 carlosgmartin

@rdyro Does this look good?

carlosgmartin avatar Jun 06 '25 23:06 carlosgmartin

I'm not entirely sure about this change: the original adabelief paper explicitly discusses weight decay, but does not use it.

The problem for optax is that weight decay is NOT scaled by the learning rate, so the user has two options for adding weight decay to an existing optimizer:

  • reimplement the optimizer chain to insert the weight decay before scale_by_learning_rate in the chain (see the sketch after this list)
  • chain the pre-made optimizers (e.g., adabelief) with another chain of (weight_decay, scale_by_learning_rate)
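For concreteness, a minimal sketch of the first option, rebuilding the chain the way adamw is assembled in alias.py (hyperparameter values are placeholders):

```python
import optax

learning_rate, weight_decay = 1e-3, 1e-4  # placeholders

# Rebuild the adabelief update rule so the decoupled weight decay is added
# before the learning-rate scaling, mirroring the adamw chain.
opt = optax.chain(
    optax.scale_by_belief(),                      # adabelief preconditioning
    optax.add_decayed_weights(weight_decay),      # decay, not yet scaled by lr
    optax.scale_by_learning_rate(learning_rate),  # scales both terms by -lr
)
```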

It'd be great if we could solve this problem more systematically, so we don't have to add extra weight_decay arguments to every popular optimizer.

Perhaps we can introduce another keyword argument to add_decayed_weights which takes in the learning rate (schedule)? @carlosgmartin
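Hypothetically (no such argument exists in optax today), the call site might look something like this, with add_decayed_weights rescaling the decay by the same schedule so it can simply be appended after a pre-made optimizer:

```python
import optax

lr_schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)

# Hypothetical API sketch only: add_decayed_weights does not currently take a
# learning-rate argument; the idea is that the decay term would be multiplied
# by the schedule internally.
opt = optax.chain(
    optax.adabelief(learning_rate=lr_schedule),
    optax.add_decayed_weights(1e-4, learning_rate=lr_schedule),  # proposed kwarg
)
```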

For a systematic fix, I'd prefer to remove the additional weight_decay keyword argument from pre-made optimizers (but we should keep it in optimizers that explicitly include it, e.g. adamw, and in those to which we already added the weight decay kwarg, for backward compatibility).

rdyro avatar Jun 06 '25 23:06 rdyro

What does @vroulet think?

carlosgmartin avatar Jun 09 '25 20:06 carlosgmartin

The original author's repository seems to include some form of weight decay (https://github.com/juntang-zhuang/Adabelief-Optimizer/tree/update_0.2.0), so having a weight decay implementation makes sense.

I agree with Robert that the current duplication of weight_decay arguments is pretty bad (in particular, the documentation is quite heavy; it would be best to have a "see_also" so people know how to add weight decay). I like the idea of maybe adding a keyword argument to add_decayed_weights (it may lead to a relatively large refactoring, though).

vroulet avatar Jun 17 '25 20:06 vroulet