
Truncated backprop using filtered transformations

Open • SNMS95 opened this issue 10 months ago • 3 comments

Hi everyone,

I am working on meta-learning and want to implement truncated backprop to estimate the meta-level gradients.

import jax.flatten_util
import optax
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp

NSTEPS = 500
KIND = 'lax'


def inner_loss(params):
    """Loss function"""
    flat_params = jax.flatten_util.ravel_pytree(params)[0]
    return jnp.sum(flat_params**2)


@eqx.filter_jit
def inner_optimization(params, opt_state, num_steps):
    """Run `num_steps` inner steps; `inner_opt` is a global defined below."""

    def inner_opt_step(carry, _):
        """Inner optimization function"""
        params, opt_state = carry
        # Get the gradients
        inner_grads = eqx.filter_grad(inner_loss)(params)
        # Update the parameters
        updates, opt_state = inner_opt.update(inner_grads, opt_state)
        params = eqx.apply_updates(params, updates)
        return (params, opt_state), None

    # Scan over the inner optimization steps
    init = (params, opt_state)
    (final_params, final_opt_state), _ = eqxi.scan(inner_opt_step, init,
                                                   None, length=num_steps, kind=KIND)
    return final_params, final_opt_state

def truncated_inner_optimization(params, opt_state, num_steps, num_steps_truncated):
    """Truncated inner optimization"""

    @eqx.filter_custom_jvp
    def wrapped_inner_optimization(params, opt_state, num_steps):
        final_params, final_opt_state = inner_optimization(
            params, opt_state, num_steps)
        return final_params, final_opt_state

    @wrapped_inner_optimization.def_jvp
    def wrapped_inner_optimization_jvp(primals, tangents):
        """Straight-through JVP: don't differentiate through these steps."""
        primal_out = wrapped_inner_optimization(*primals)
        # Pass the (params, opt_state) tangents through unchanged; the
        # tangent for `num_steps` is dropped, as it isn't differentiable.
        tangent_out = tangents[:2]
        return primal_out, tangent_out

    # Run the first `num_steps_truncated` steps without differentiating
    # through them, via the wrapped function
    final_params, final_opt_state = eqx.filter_jit(
        wrapped_inner_optimization)(params, opt_state, num_steps_truncated)

    # Differentiate through the remaining steps as usual
    final_params, final_opt_state = inner_optimization(
        final_params, final_opt_state, num_steps - num_steps_truncated)
    return final_params, final_opt_state

def outer_loss2(params):
    """Outer loss function"""
    # Run optimization
    final_params, _ = truncated_inner_optimization(
        params, inner_opt_state, NSTEPS, NSTEPS - 100)
    # Compute the loss
    flat_final_params = jax.flatten_util.ravel_pytree(final_params)[0]
    return jnp.sum(flat_final_params**3)

# Test
# create a pytree for the parameters 
params = {'w': jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))}
inner_opt = optax.adam(1e-3)
inner_opt_state = inner_opt.init(params)
outer_grads = eqx.filter_grad(outer_loss2)(params)
print(outer_grads)

When I checked the memory usage, this method seems to work, but it feels hacky. Is there a better way to do this?

SNMS95 • Feb 21 '25 15:02

This looks pretty good to me! It's probably worth tweaking things slightly to have a single JIT wrapping the whole thing (including your final outer_grads = eqx.filter_grad(...)), but otherwise I think this looks about as good as it gets :)
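
For example, a minimal sketch of that suggestion (untested; compute_outer_grads is just an illustrative name, reusing outer_loss2 and params from the snippet above):

@eqx.filter_jit
def compute_outer_grads(params):
    # One JIT boundary around the whole computation, grad included.
    return eqx.filter_grad(outer_loss2)(params)

outer_grads = compute_outer_grads(params)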

patrick-kidger • Feb 21 '25 17:02

Hi @patrick-kidger

Here is an even more minimal example, but it errors out.

import jax
import jax.flatten_util
import jax.numpy as jnp
import equinox as eqx
from collections import namedtuple

class Container(eqx.Module):
    inner_model: eqx.Module
    outer_model: tuple  # a namedtuple instance

    def __init__(self, outer_param_dict):
        self.inner_model = eqx.nn.MLP(1, 1, 5, 3, key=jax.random.PRNGKey(0))
        # Convert dictionary to namedtuple
        self.outer_model = namedtuple('OuterModel', outer_param_dict.keys())(**outer_param_dict)

    def __call__(self, x):
        return self.inner_model(x)
    
def new_inner_model(model):
    """Update the inner model. (`inner_filter` is a global defined below.)"""
    inner_params, inner_static = eqx.partition(model, inner_filter)
    val = eqx.combine(inner_params, inner_static)(jnp.array([1.0]))
    inner_params = jax.tree.map(lambda x: x*val, inner_params)
    return eqx.combine(inner_params, inner_static)

def wrapper(fn):
    """Wrap to do truncated backprop"""

    @eqx.filter_custom_jvp
    def wrapped_fn(*args):
        return fn(*args)
    
    @wrapped_fn.def_jvp
    def _jvp(primals, tangents):
        primals_out = wrapped_fn(*primals)
        # `tangents` is a tuple with one entry per positional argument;
        # this assumes the wrapped function takes a single argument.
        return primals_out, tangents[0]

    return wrapped_fn

new_inner_model_wrapped = wrapper(new_inner_model)

def loss(outer_params, outer_static):
    a = outer_params.outer_model.a
    b = outer_params.outer_model.b
    # Do something to the inner params
    model = eqx.combine(outer_params, outer_static)
    inner_param, inner_static = eqx.partition(model, inner_filter)
    inner_param = jax.tree.map(lambda x: x*(a + b**2), inner_param)
    model = eqx.combine(inner_param, inner_static)

    # model = new_inner_model(model)  # This works
    model = new_inner_model_wrapped(model)  # This doesn't work

    inner_param, inner_static = eqx.partition(model, inner_filter)
    flat_inner_params = jax.flatten_util.ravel_pytree(inner_param)[0]
    return a + b + jnp.sum(flat_inner_params)

# test
import jax.tree as jt
outer_param_dict = {'a': jnp.array(1.0), 'b': jnp.array(2.0)}
container = Container(outer_param_dict)

base_filter = jt.map(lambda _: False, container)

def _bilevel_igor_filters(model):
    """Build filters selecting the inner- and outer-model array leaves."""
    task_filter = jt.map(eqx.is_array, model.inner_model)
    task_filter = eqx.tree_at(
        lambda tree: tree.inner_model, base_filter, task_filter)
    meta_filter = jt.map(eqx.is_array_like, model.outer_model)
    meta_filter = eqx.tree_at(
        lambda tree: tree.outer_model, base_filter, meta_filter)
    return task_filter, meta_filter

inner_filter, outer_filter = _bilevel_igor_filters(container)

outer_params, outer_static = eqx.partition(container, outer_filter)
eqx.filter_grad(loss)(outer_params, outer_static).outer_model

Do you know what should be done to make this work?

SNMS95 • Feb 25 '25 16:02

The output of an eqx.filter_custom_jvp must still consist only of JAX types. In this case you're returning the output of new_inner_model, which still has its static components.

(I think)
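
If so, one possible fix is to keep the static components outside the custom-JVP boundary entirely. An untested sketch along those lines (the dynamic/static split and the call helper are illustrative, and it assumes fn leaves the static components unchanged):

def wrapper(fn):
    """Wrap to do truncated backprop, passing only arrays through the JVP."""

    @eqx.filter_custom_jvp
    def wrapped_fn(dynamic, static):
        out = fn(eqx.combine(dynamic, static))
        # Only JAX types may cross the custom-JVP boundary.
        return eqx.filter(out, eqx.is_array)

    @wrapped_fn.def_jvp
    def _jvp(primals, tangents):
        primals_out = wrapped_fn(*primals)
        # Straight-through: reuse the input's tangent. `static` is
        # non-differentiable, so its tangent is ignored.
        return primals_out, tangents[0]

    def call(model):
        dynamic, static = eqx.partition(model, eqx.is_array)
        # Re-attach the input's static components to the array-only output.
        return eqx.combine(wrapped_fn(dynamic, static), static)

    return call

With this version, new_inner_model_wrapped = wrapper(new_inner_model) can be called on the full model as before, while the custom JVP itself only ever sees arrays.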

patrick-kidger • Feb 28 '25 18:02