
Correct usage of look-ahead optimizer

Open SNMS95 opened this issue 2 months ago • 3 comments

I want to use the look-ahead optimizer but it does not seem to fit the optax pattern exactly.

import optax
import jax
import jax.numpy as jnp

def fn_to_optimize(x):
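    # note: x.fast only exists on a LookaheadParams pair; with the plain
    # jnp array passed below, this line raises an AttributeError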
    x = x.fast
    return jnp.sum((x) ** 2)

params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
# params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)

for step in range(100):
    loss, grads = jax.value_and_grad(fn_to_optimize)(params)
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")

I would have expected it to be a drop-in replacement like the other optimizers, but that does not seem to be the case.

When I look at the source code (https://github.com/google-deepmind/optax/blob/5bd909532d9814667c188ac09b675183118e76eb/optax/_src/lookahead.py#L99), the init function is very clear. However, the update function is confusing: it expects params: LookaheadParams, unlike init.
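As far as I can tell from the source (so this is just my reading, sketched below), the wrapper operates on a pair of parameter copies rather than a single pytree, and the updates it returns mirror that pair:

import jax.numpy as jnp
import optax

# LookaheadParams bundles two copies of the model parameters:
# the fast weights (stepped by the inner optimizer on every iteration) and
# the slow weights (pulled toward the fast ones every sync_period steps).
pair = optax.LookaheadParams.init_synced(jnp.array([2.0, 2.0]))
print(pair.fast, pair.slow)  # both copies start out equal

# Because update() produces updates for both copies, it needs the full
# LookaheadParams, whereas init() gets by with the fast parameters alone.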

The following seems to be the right way to do it, but it is not very intuitive.

import optax
import jax
import jax.numpy as jnp

def fn_to_optimize(x):
    return jnp.sum((x) ** 2)

params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)

for step in range(100):
    loss, grads = jax.value_and_grad(fn_to_optimize)(params.fast)
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")

SNMS95 • Oct 07 '25 19:10

Did you take a look at https://optax.readthedocs.io/en/latest/_collections/examples/lookahead_mnist.html?

vroulet • Oct 10 '25 23:10

Hey, I did. However, in the original paper the algorithm is bilevel and the slow parameters are the "main" parameters. In the example, the slow parameters are only used during testing. Further, there is only a one-way coupling, i.e. the fast weights affect the slow ones but not the other way around. So, in fact, this is equivalent to just running the fast optimizer for training, while keeping track of the slow parameters for inference.

From the paper, this is also a method to make the fast optimizer more robust to hyper-parameter choices.
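To make the coupling I mean concrete, here is a rough standalone sketch (my own notation, not optax code) of the synchronization step as the paper describes it, applied every sync_period steps:

import jax.numpy as jnp

def sync_step(fast, slow, slow_step_size=0.5):
    # The slow weights take an interpolation step toward the fast weights ...
    new_slow = slow + slow_step_size * (fast - slow)
    # ... and the fast weights restart from the updated slow weights, which is
    # the feedback from slow to fast that I am missing in the example.
    new_fast = new_slow
    return new_fast, new_slow

fast, slow = jnp.array([1.0, -2.0]), jnp.array([0.0, 0.0])
print(sync_step(fast, slow))  # -> ([0.5, -1.0], [0.5, -1.0])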

SNMS95 • Oct 11 '25 08:10

@mkunesch could you take a look at this? You were the author of that optimizer. Your help is greatly appreciated :)

vroulet • Nov 14 '25 18:11