
Zero implicit gradients when using `ImplicitAdjoint` with CG solver

Open itk22 opened this issue 3 months ago • 4 comments

Hi @patrick-kidger and @packquickly,

I was trying to implement the following meta-learning example from jax-opt in optimistix: Few-shot Adaptation with Model Agnostic Meta-Learning. However, I ran into an issue with implicit differentiation through the inner loop. The example below runs well with optx.RecursiveCheckpointAdjoint, but when I try to recreate the iMAML setup by using optx.ImplicitAdjoint with a CG solver with 20 steps, all the meta-gradients are zero and the meta-optimiser makes no updates at all during training. Could you please help me identify the issue with the code? It seems to be an implementation detail of implicit adjoints that differs between jax-opt and optimistix.

Here is an MWE:

import optimistix as optx
import equinox as eqx
import lineax as lx
import jax
import jax.random as jr
import jax.numpy as jnp
import optax

key = jr.PRNGKey(0)
model = eqx.nn.MLP(1, 1, 40, 2, key=key)

sine_target = lambda x: 1.0 * jnp.sin(x - 0.5) # Target function
x = jr.normal(key, (10, 1)) # Randomly drawn inputs for validation
y_true = sine_target(x)

opt = optx.OptaxMinimiser(optax.adam(1e-3, eps_root=1e-8), 1e-7, 1e-7)
params, static = eqx.partition(model, eqx.is_inexact_array)

def apply_model(params, x):
    model = eqx.combine(params, static)
    return jax.vmap(model)(x)

def loss_fn(params, args):
    y_pred = apply_model(params, x)
    loss = jnp.mean(jnp.square(y_pred - y_true))
    return loss, loss

def adapt_fn(params):
    sol = optx.minimise(loss_fn,
                        opt,
                        params,
                        None,
                        has_aux=True,
                        max_steps=2,
                        throw=False,
                        adjoint=optx.ImplicitAdjoint(lx.CG(1e-7, 1e-7, max_steps=10)),
                        tags=lx.positive_semidefinite_tag)
    return sol.aux # Return the final loss only

loss, grad = jax.value_and_grad(adapt_fn)(params)

print(f"Final loss: {loss:.5f}")
print(f"Gradient: {grad.layers[0].weight}")

itk22, Mar 28 '24 13:03