
[Feature Request] Normalized gradient descent

Open smorad opened this issue 1 year ago • 6 comments

Optax has various clipping operators, but as far as I can tell, it cannot scale updates by their global gradient norm. Adding this as a transformation that can be chained with others would let us use normalized gradient descent methods (e.g. normalized Adam).

A simple implementation might look like:

import jax
import jax.numpy as jnp
import optax

def scale_by_norm(scale: float = 1.0, eps: float = 1e-6):
  def init_fn(params):
    del params
    return optax.EmptyState()  # the transformation is stateless

  def update_fn(updates, state, params=None):
    del params
    # Divide by the larger of (global norm of the updates + eps) and `scale`.
    g_norm = jnp.maximum(optax.global_norm(updates) + eps, scale)

    def scale_fn(t):
      return t / g_norm

    updates = jax.tree_util.tree_map(scale_fn, updates)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)

smorad, Oct 09 '23 15:10

Do you have a reference to this specific way of normalising?

mtthss, Oct 10 '23 07:10

This textbook describes it fairly well. My example might be a little fancy; you could instead replace the `maximum` with:

g_norm = (optax.global_norm(updates) + eps) / scale

In this case, `scale` corresponds to alpha in Eq. 6.
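Spelled out, the division variant maps a gradient g to the update u below. This is derived directly from the code line, not copied from the textbook's Eq. 6, with α standing for `scale` and ε for `eps`; it shows why the result has norm close to α for any nonzero gradient:

```latex
u = \frac{g}{(\lVert g \rVert_2 + \varepsilon)/\alpha}
  = \alpha \, \frac{g}{\lVert g \rVert_2 + \varepsilon},
\qquad
\lVert u \rVert_2 = \alpha \, \frac{\lVert g \rVert_2}{\lVert g \rVert_2 + \varepsilon} \approx \alpha
```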

smorad, Oct 10 '23 09:10

Sounds like it could be a good addition. Do you want to put together a PR?

mtthss, Oct 10 '23 10:10

This seems like a simple extension of:

https://github.com/google-deepmind/optax/blob/841be5a860bdf271c0b4ee4b757710bd9497537d/optax/contrib/sam.py#L63-L80

@mtthss, can I take this up?

SauravMaheshkar, Nov 25 '23 20:11

I think this might already be implemented in `clip_by_global_norm`. IIRC the code there scales the gradient rather than clipping it. It might be worth double-checking before starting.

smorad, Nov 25 '23 21:11

`clip_by_global_norm` clips but does not necessarily normalize: if the global norm of the updates is below the clipping threshold, they are returned as is. In other words, clipping projects onto a ball, whereas @smorad, you want to project onto a sphere. I think @SauravMaheshkar pointed out a good starting point.

vroulet, Feb 05 '24 12:02

Fixed in #958

CC: @fabianp

SauravMaheshkar, Jul 09 '24 10:07