
[Feature Request] Normalized gradient descent

Open smorad opened this issue 10 months ago • 6 comments

Optax has various clipping operators, but as far as I can tell, it cannot scale updates by the gradient norm. Adding this capability as a chainable transformation would let us build normalized gradient descent methods (e.g. normalized Adam).

A simple implementation might look like

import jax
import jax.numpy as jnp
import optax

def scale_by_norm(scale: float = 1.0, eps: float = 1e-6):
  def init_fn(params):
    del params
    return optax.EmptyState()  # this transformation keeps no state

  def update_fn(updates, state, params=None):
    del params
    # Normalize by the global norm, but never divide by less than `scale`,
    # so updates whose norm is already below `scale` pass through unchanged.
    g_norm = jnp.maximum(optax.global_norm(updates) + eps, scale)

    def scale_fn(t):
      return t / g_norm

    updates = jax.tree_util.tree_map(scale_fn, updates)
    return updates, state

  return optax.GradientTransformation(init_fn, update_fn)
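
For reference, a transformation like this could then be composed with other Optax transformations via optax.chain. A minimal usage sketch, assuming the scale_by_norm definition above (the learning rate 1e-3 and the toy params/grads are illustrative only):

import jax.numpy as jnp
import optax

optimizer = optax.chain(
    scale_by_norm(),     # normalize updates (sketch above)
    optax.scale(-1e-3),  # descend with learning rate 1e-3
)

params = {'w': jnp.zeros(3)}
opt_state = optimizer.init(params)
grads = {'w': jnp.ones(3)}
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)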

smorad commented on Oct 09 '23