optax
optax copied to clipboard
Add Optimistic Adam
Feature request: add Optimistic Adam, the optimistic variant of Adam introduced in [1]. Among other things, it addresses the limit-cycling behavior observed in GAN training.
It could likely be implemented by chaining `scale_by_adam` with `scale_by_optimistic_gradient` via `optax.chain`.
References:
[1] Constantinos Daskalakis, Andrew Ilyas, Vasilis Syrgkanis, Haoyang Zeng. Training GANs with Optimism. ICLR 2018. Available on OpenReview and arXiv (arXiv:1711.00141).
Below is a demonstration:
import argparse
import jax
import optax
from jax import lax, numpy as jnp
from matplotlib import pyplot as plt, rcParams
def optimistic_sgd(learning_rate, strength):
    """Build optimistic gradient descent as a single optax transformation.

    Both coefficients are negated before being handed to
    ``optax.scale_by_optimistic_gradient``. The transformation's output is
    added directly to the parameters, so the negation turns it into a
    descent step.

    Per optax's documented rule ``(alpha + beta) * g_t - beta * g_{t-1}``,
    the resulting update is
    ``-learning_rate * g_t - strength * (g_t - g_{t-1})``.

    Args:
        learning_rate: Step size applied to the current gradient.
        strength: Weight of the optimistic (gradient-difference) term.

    Returns:
        An optax gradient transformation.
    """
    alpha = -learning_rate
    beta = -strength
    return optax.scale_by_optimistic_gradient(alpha, beta)
def optimistic_adam(learning_rate, strength):
    """Build Optimistic Adam by chaining Adam scaling with optimism.

    The gradients are first rescaled by ``optax.scale_by_adam()`` (default
    hyper-parameters). The Adam-scaled direction is then passed through
    ``optax.scale_by_optimistic_gradient``. That step uses negated
    coefficients, so that the output, added directly to the parameters, is
    a descent step.

    Args:
        learning_rate: Step size applied to the current Adam direction.
        strength: Weight of the optimistic (difference) term.

    Returns:
        An optax gradient transformation.
    """
    adam_scaling = optax.scale_by_adam()
    optimism = optax.scale_by_optimistic_gradient(-learning_rate, -strength)
    # Order matters: Adam normalizes first, optimism is applied afterwards.
    return optax.chain(adam_scaling, optimism)
def optimistic_adam_wrong_order(learning_rate, strength):
    """Chain the optimistic step *before* Adam (kept for comparison).

    This uses the same two components as the Adam-then-optimism chain, but
    in reverse order:

    * the optimistic transformation (with the negated coefficients) acts on
      the raw gradients;
    * ``optax.scale_by_adam()`` is the final stage, and its output is used
      as the update.

    NOTE(review): nothing rescales the update after Adam here. Adam's
    normalization therefore presumably cancels the ``learning_rate``
    factor, which is likely why this ordering is labeled "wrong".

    Args:
        learning_rate: Coefficient on the current gradient (negated).
        strength: Weight of the optimistic term (negated).

    Returns:
        An optax gradient transformation.
    """
    optimism = optax.scale_by_optimistic_gradient(-learning_rate, -strength)
    adam_scaling = optax.scale_by_adam()
    return optax.chain(optimism, adam_scaling)
def bilinear_utility_fn(params):
    """Bilinear saddle point.

    Player one receives ``x * y`` and player two receives ``-x * y``, so
    the game is zero-sum. It has a unique Nash equilibrium at the origin.

    Args:
        params: Two-element array ``(x, y)``; ``x`` belongs to player one
            and ``y`` to player two.

    Returns:
        Length-2 array of the players' utilities.
    """
    x, y = params
    payoff = x * y
    return jnp.stack((payoff, -payoff))
def dirac_gan_utility_fn(params):
    """Dirac GAN: https://arxiv.org/abs/1801.04406.

    Player one receives ``log(1 + exp(x * y))`` (softplus of ``x * y``,
    computed stably via ``logaddexp``). Player two receives its negation,
    so the game is zero-sum. It has a unique Nash equilibrium at the
    origin.

    Args:
        params: Two-element array ``(x, y)``.

    Returns:
        Length-2 array of the players' utilities.
    """
    x, y = params
    softplus_xy = jnp.logaddexp(0, x * y)
    return jnp.stack((softplus_xy, -softplus_xy))
def parse_args(argv=None):
    """Parse the demo's command-line options.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave as before.

    Returns:
        ``argparse.Namespace`` with ``game``, ``lr``, ``iters`` and
        ``strength``.

    Raises:
        SystemExit: via argparse on invalid options. This covers an unknown
            ``--game`` and a non-positive ``--iters``.
    """
    p = argparse.ArgumentParser(
        description="Compare (optimistic) SGD and Adam on two-player games."
    )
    # Restrict to the games main() knows how to build, so typos fail early
    # with a usage message instead of a NotImplementedError traceback.
    p.add_argument(
        "--game",
        type=str,
        default="bilinear",
        choices=("bilinear", "dirac_gan"),
        help="game to simulate",
    )
    p.add_argument("--lr", type=float, default=1e-3, help="learning rate")
    p.add_argument(
        "--iters", type=int, default=10**5, help="number of update steps"
    )
    p.add_argument(
        "--strength",
        type=float,
        default=1e-1,
        help="weight of the optimistic term",
    )
    args = p.parse_args(argv)
    if args.iters < 1:
        p.error("--iters must be a positive integer")
    return args
def main():
args = parse_args()
match args.game:
case "bilinear":
utility_fn = bilinear_utility_fn
case "dirac_gan":
utility_fn = dirac_gan_utility_fn
case _:
raise NotImplementedError(args.game)
def update(state, _):
params, opt_state = state
jac = jax.jacobian(utility_fn)(params)
grads = jax.tree.map(jnp.diag, jac)
updates, opt_state = opt.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return (params, opt_state), params
_, ax_distances = plt.subplots()
_, ax_params = plt.subplots()
params = jnp.array([1.0, 2.0])
for label, opt in [
("SGD", optax.sgd(args.lr)),
("Adam", optax.adam(args.lr)),
("Optimistic SGD", optimistic_sgd(args.lr, args.strength)),
("Optimistic Adam", optimistic_adam(args.lr, args.strength)),
]:
opt_state = opt.init(params)
_, params_hist = lax.scan(
update, (params, opt_state), length=args.iters
)
distances_to_origin = jnp.hypot(*params_hist.T)
ax_params.plot(*params_hist.T, label=label, lw=1)
ax_distances.plot(distances_to_origin, label=label, lw=1)
ax_params.legend()
ax_distances.legend()
ax_params.set(title="parameters")
ax_distances.set(xlabel="iteration", ylabel="distance to origin")
rcParams["savefig.dpi"] = 300
plt.show()
if __name__ == "__main__":
main()
Outputs for --game=bilinear:
Outputs for --game=dirac_gan:
I can submit a PR to add an `optax.optimistic_adam` function.
This is great, @carlosgmartin!
Would you be willing to contribute such an example to the example gallery (https://optax.readthedocs.io/en/latest/gallery.html)? I think it would be very valuable even though there is the somewhat related https://optax.readthedocs.io/en/latest/_collections/examples/ogda_example.html example; the two examples would complement each other. What do you think?
I would also be OK with adding an `optimistic_adam` solver to optax (although that would require some work on its docstring and tests).
@fabianp Done: #1089.