Add Optimistic Adam

Open carlosgmartin opened this issue 1 year ago • 3 comments

Feature request: Add Optimistic Adam, an optimistic variant of Adam introduced in [1]. Among other things, it addresses limit-cycling behavior in GAN training.

Perhaps it can be implemented by combining scale_by_adam with scale_by_optimistic_gradient using chain.
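
For intuition about the optimistic correction on its own, here is a minimal sketch without Adam and without optax, on a bilinear game where player 1 minimizes x * y and player 2 minimizes -x * y. It assumes the update params - lr * g_t - strength * (g_t - g_prev), which is an assumption about what scale_by_optimistic_gradient(-lr, -strength) computes, not something confirmed in this thread. The names bilinear_grads and run are illustrative, and the step sizes are arbitrary.

```python
import math


def bilinear_grads(x, y):
    """Each player's gradient of its own loss: player 1 minimizes x * y,
    player 2 minimizes -x * y."""
    return y, -x


def run(lr, strength, iters=2000, x=1.0, y=2.0):
    """Simultaneous gradient play; strength=0 recovers plain gradient descent."""
    # Assumption: on the first step the previous gradient equals the current
    # one, so the correction term is zero.
    prev_gx, prev_gy = bilinear_grads(x, y)
    for _ in range(iters):
        gx, gy = bilinear_grads(x, y)
        # Usual gradient step plus a correction proportional to the change
        # in the gradient since the previous step.
        x, y, prev_gx, prev_gy = (
            x - lr * gx - strength * (gx - prev_gx),
            y - lr * gy - strength * (gy - prev_gy),
            gx,
            gy,
        )
    return math.hypot(x, y)


plain = run(lr=0.1, strength=0.0)
optimistic = run(lr=0.1, strength=0.1)
print(f"distance to origin: plain={plain:.3g}, optimistic={optimistic:.3g}")
```

With these settings the plain iterates should spiral away from the equilibrium at the origin, while the optimistic iterates should shrink toward it.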

References:

  1. Constantinos Daskalakis, Andrew Ilyas, Vasilis Syrgkanis, Haoyang Zeng. Training GANs with Optimism. ICLR 2018. OpenReview. ArXiv.

carlosgmartin • Oct 02 '24 00:10

Below is a demonstration:

import argparse

import jax
import optax
from jax import lax, numpy as jnp
from matplotlib import pyplot as plt, rcParams


def optimistic_sgd(learning_rate, strength):
    return optax.scale_by_optimistic_gradient(-learning_rate, -strength)


def optimistic_adam(learning_rate, strength):
    return optax.chain(
        optax.scale_by_adam(),
        optax.scale_by_optimistic_gradient(-learning_rate, -strength),
    )


def optimistic_adam_wrong_order(learning_rate, strength):
    return optax.chain(
        optax.scale_by_optimistic_gradient(-learning_rate, -strength),
        optax.scale_by_adam(),
    )


def bilinear_utility_fn(params):
    """Bilinear saddle point.
    Has a unique Nash equilibrium at the origin."""
    x, y = params
    z = x * y
    return jnp.stack([z, -z])


def dirac_gan_utility_fn(params):
    """Dirac GAN: https://arxiv.org/abs/1801.04406.
    Has a unique Nash equilibrium at the origin."""
    x, y = params
    z = jnp.logaddexp(0, x * y)
    return jnp.stack([z, -z])


def parse_args():
    p = argparse.ArgumentParser()
    p.add_argument("--game", type=str, default="bilinear")
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--iters", type=int, default=10**5)
    p.add_argument("--strength", type=float, default=1e-1)
    return p.parse_args()


def main():
    args = parse_args()

    match args.game:
        case "bilinear":
            utility_fn = bilinear_utility_fn
        case "dirac_gan":
            utility_fn = dirac_gan_utility_fn
        case _:
            raise NotImplementedError(args.game)

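    def update(state, _):
        # `opt` is bound by the loop below; it is looked up when scan traces this.
        params, opt_state = state
        # jac[i, j] is the derivative of player i's utility w.r.t. parameter j.
        jac = jax.jacobian(utility_fn)(params)
        # Keep the diagonal: each player's gradient w.r.t. its own parameter.
        grads = jax.tree.map(jnp.diag, jac)
        updates, opt_state = opt.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return (params, opt_state), params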

    _, ax_distances = plt.subplots()
    _, ax_params = plt.subplots()

    params = jnp.array([1.0, 2.0])
    for label, opt in [
        ("SGD", optax.sgd(args.lr)),
        ("Adam", optax.adam(args.lr)),
        ("Optimistic SGD", optimistic_sgd(args.lr, args.strength)),
        ("Optimistic Adam", optimistic_adam(args.lr, args.strength)),
    ]:
        opt_state = opt.init(params)
        _, params_hist = lax.scan(
            update, (params, opt_state), length=args.iters
        )
        distances_to_origin = jnp.hypot(*params_hist.T)
        ax_params.plot(*params_hist.T, label=label, lw=1)
        ax_distances.plot(distances_to_origin, label=label, lw=1)

    ax_params.legend()
    ax_distances.legend()
    ax_params.set(title="parameters")
    ax_distances.set(xlabel="iteration", ylabel="distance to origin")
    rcParams["savefig.dpi"] = 300
    plt.show()


if __name__ == "__main__":
    main()

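Outputs for --game=bilinear: [figures: "parameters" trajectory plot; distance to origin vs. iteration]

Outputs for --game=dirac_gan: [figures: "parameters" trajectory plot; distance to origin vs. iteration]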

I can submit a PR to create an optax.optimistic_adam function.
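
One step in the script that may not be obvious is the use of jax.jacobian followed by jnp.diag. The sketch below isolates it on the bilinear game: the Jacobian holds every player's utility differentiated with respect to every parameter, and the diagonal keeps only each player's derivative with respect to the parameter it controls. The cross-check with jax.grad is added here purely for illustration.

```python
import jax
from jax import numpy as jnp


def bilinear_utility_fn(params):
    # Same game as in the script: player 0 gets x * y, player 1 gets -x * y.
    x, y = params
    z = x * y
    return jnp.stack([z, -z])


params = jnp.array([1.0, 2.0])

# jac[i, j] is the derivative of player i's utility w.r.t. parameter j.
jac = jax.jacobian(bilinear_utility_fn)(params)

# Player i controls parameter i, so its own gradient is the diagonal entry.
grads = jnp.diag(jac)

# Cross-check: differentiate each player's utility separately and keep
# only the entry for that player's own parameter.
expected = jnp.array([
    jax.grad(lambda p: bilinear_utility_fn(p)[0])(params)[0],
    jax.grad(lambda p: bilinear_utility_fn(p)[1])(params)[1],
])
print(jac)
print(grads, expected)
```

Computing the full Jacobian is wasteful for large parameter counts, but for a two-parameter toy game it keeps the code short.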

carlosgmartin • Oct 02 '24 01:10

This is great, @carlosgmartin!

Would you be willing to contribute such an example to the example gallery (https://optax.readthedocs.io/en/latest/gallery.html)? I think it would be very valuable. There is the somewhat related https://optax.readthedocs.io/en/latest/_collections/examples/ogda_example.html, but I think the two examples would be complementary. What do you think?

I would also be OK with adding the optimistic_adam solver to optax, although that would require a bit of work on the docstring and tests for this solver.

fabianp • Oct 02 '24 08:10

@fabianp Done: #1089.

carlosgmartin • Oct 03 '24 23:10