Correct usage of look-ahead optimizer
I want to use the look-ahead optimizer, but it does not seem to fit the usual optax pattern exactly.
```python
import optax
import jax
import jax.numpy as jnp


def fn_to_optimize(x):
    return jnp.sum(x ** 2)


params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
# params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)

for step in range(100):
    loss, grads = jax.value_and_grad(fn_to_optimize)(params)
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")
```
I would have expected it to be a drop-in replacement like the other optimizers, but that does not seem to be the case.
When I look at the source code (https://github.com/google-deepmind/optax/blob/5bd909532d9814667c188ac09b675183118e76eb/optax/_src/lookahead.py#L99), the init function is very clear.
However, the update function is confusing: it expects `params: LookaheadParams`, unlike `init`.
The following seems to be the right way to do it, but it is not very intuitive.
```python
import optax
import jax
import jax.numpy as jnp


def fn_to_optimize(x):
    return jnp.sum(x ** 2)


params = jnp.array([2.0, 2.0])
fast_optimizer = optax.adam(1e-1)
solver = optax.lookahead(fast_optimizer, sync_period=5, slow_step_size=0.5)
params = optax.LookaheadParams.init_synced(params)
state = solver.init(params)

for step in range(100):
    loss, grads = jax.value_and_grad(fn_to_optimize)(params.fast)
    updates, state = solver.update(grads, state, params)
    params = optax.apply_updates(params, updates)
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss}, Params: {params}")
```
Did you take a look at https://optax.readthedocs.io/en/latest/_collections/examples/lookahead_mnist.html?
Hey, I did. However, in the original paper the algorithm is bilevel and the slow parameters are the "main" parameters. In the example, the slow parameters are only used during testing. Further, there is only a one-way coupling, i.e. the fast weights affect the slow ones but not the other way around. So, in fact, this is equivalent to just running the fast optimizer for training, while keeping track of the slow parameters for inference.
From the paper, this is also a method to make the fast optimizer more robust to hyper-parameter choices.
@mkunesch could you take a look at this? You were the author of that optimizer. Your help is greatly appreciated :)