[Feature Request] Normalized gradient descent
Optax has various clipping operators, but as far as I can tell, it cannot scale updates by the gradient norm. Adding this capability as a chainable transformation would allow normalized gradient descent methods (e.g. normalized Adam).
A simple implementation might look like

```python
import jax
import jax.numpy as jnp
import optax

def scale_by_norm(scale: float = 1.0, eps: float = 1e-6):
    def init_fn(params):
        del params
        return optax.EmptyState()

    def update_fn(updates, state, params=None):
        del params
        g_norm = jnp.maximum(optax.global_norm(updates) + eps, scale)

        def scale_fn(t):
            return t / g_norm

        updates = jax.tree_util.tree_map(scale_fn, updates)
        return updates, state

    return optax.GradientTransformation(init_fn, update_fn)
```
Do you have a reference to this specific way of normalising?
This textbook describes it fairly well. My example might be a little fancy, but you could replace the maximum with

```python
g_norm = (optax.global_norm(updates) + eps) / scale
```

In this case, `scale` would refer to alpha in Eq 6.
Sounds like it could be a good addition. Do you want to put together a PR?
Seems like a simple extension of
https://github.com/google-deepmind/optax/blob/841be5a860bdf271c0b4ee4b757710bd9497537d/optax/contrib/sam.py#L63-L80
@mtthss can I take this up?
I think this might actually be implemented in `clip_by_global_norm`. IIRC the code there actually scales the gradient rather than clipping it. Might be worth double-checking before starting.
`clip_by_global_norm` clips but does not necessarily normalize (if the updates are smaller than the clip norm, they are returned as-is). In other words, clip projects onto a ball, and @smorad you want to project onto a sphere.
I think @SauravMaheshkar pointed out a good starting point.
Fixed in #958
CC: @fabianp