Truncated backprop using filtered transformations
Hi everyone,
I am working on meta-learning and want to implement truncated backprop to estimate the meta-level gradients. Here is my current attempt:
import jax.flatten_util
import optax
import equinox as eqx
import equinox.internal as eqxi
import jax
import jax.numpy as jnp

NSTEPS = 500
KIND = 'lax'


def inner_loss(params):
    """Loss function"""
    flat_params = jax.flatten_util.ravel_pytree(params)[0]
    return jnp.sum(flat_params**2)


@eqx.filter_jit
def inner_optimization(params, opt_state, num_steps):
    def inner_opt_step(carry, _):
        """Inner optimization function"""
        params, opt_state = carry
        # Get the gradients
        inner_grads = eqx.filter_grad(inner_loss)(params)
        # Update the parameters
        updates, opt_state = inner_opt.update(inner_grads, opt_state)
        params = eqx.apply_updates(params, updates)
        return (params, opt_state), None

    # make a scan
    init = (params, opt_state)
    (final_params, final_opt_state), _ = eqxi.scan(
        inner_opt_step, init, None, length=num_steps, kind=KIND)
    return final_params, final_opt_state


def truncated_inner_optimization(params, opt_state, num_steps, num_steps_truncated):
    """Truncated inner optimization"""

    @eqx.filter_custom_jvp
    def wrapped_inner_optimization(params, opt_state, num_steps):
        final_params, final_opt_state = inner_optimization(
            params, opt_state, num_steps)
        return final_params, final_opt_state

    @wrapped_inner_optimization.def_jvp
    def wrapped_inner_optimization_jvp(primals, tangents):
        """Truncated inner optimization"""
        primal_out = wrapped_inner_optimization(*primals)
        tangent_out = tangents[:2]
        return primal_out, tangent_out

    # Run the inner optimization with wrapped fn
    final_params, final_opt_state = eqx.filter_jit(
        wrapped_inner_optimization)(params, opt_state, num_steps_truncated)
    # Run the remaining steps
    final_params, final_opt_state = inner_optimization(
        final_params, final_opt_state, num_steps - num_steps_truncated)
    return final_params, final_opt_state


def outer_loss2(params):
    """Outer loss function"""
    # Run optimization
    final_params, _ = truncated_inner_optimization(
        params, inner_opt_state, NSTEPS, NSTEPS - 100)
    # Compute the loss
    flat_final_params = jax.flatten_util.ravel_pytree(final_params)[0]
    return jnp.sum(flat_final_params**3)


# Test
# create a pytree for the parameters
params = {'w': jax.random.normal(jax.random.PRNGKey(0), (1000, 1000))}
inner_opt = optax.adam(1e-3)
inner_opt_state = inner_opt.init(params)
outer_grads = eqx.filter_grad(outer_loss2)(params)
print(outer_grads)
When I checked the memory usage, this method seems to be working, but it feels hacky. Is there a better way to do this?
This looks pretty good to me! It's probably worth tweaking things slightly to have a single JIT wrapping the whole thing (including your final outer_grads = eqx.filter_grad(...)), but otherwise I think this looks about as good as it gets :)
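For reference, a minimal sketch of that suggestion, reusing outer_loss2 and params from the snippet above (the helper name outer_grad_fn is made up for illustration):

# Sketch only: wrap the whole outer gradient computation in a single JIT,
# assuming outer_loss2, params, inner_opt and inner_opt_state are defined as above.
@eqx.filter_jit
def outer_grad_fn(params):
    # filter_grad differentiates with respect to the array leaves of params
    return eqx.filter_grad(outer_loss2)(params)

outer_grads = outer_grad_fn(params)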
Hi @patrick-kidger
Here is an even more minimal example, but it errors out.
import jax
import jax.flatten_util
import jax.numpy as jnp
import equinox as eqx
from collections import namedtuple


class Container(eqx.Module):
    inner_model: eqx.Module
    outer_model: namedtuple

    def __init__(self, outer_param_dict):
        self.inner_model = eqx.nn.MLP(1, 1, 5, 3, key=jax.random.PRNGKey(0))
        # Convert dictionary to namedtuple
        self.outer_model = namedtuple('OuterModel', outer_param_dict.keys())(**outer_param_dict)

    def __call__(self, x):
        return self.inner_model(x)


def new_inner_model(model):
    """Update the inner model"""
    inner_params, inner_static = eqx.partition(model, inner_filter)
    val = eqx.combine(inner_params, inner_static)(jnp.array([1.0]))
    inner_params = jax.tree.map(lambda x: x*val, inner_params)
    return eqx.combine(inner_params, inner_static)


def wrapper(fn):
    """Wrap to do truncated backprop"""

    @eqx.filter_custom_jvp
    def wrapped_fn(*args):
        return fn(*args)

    @wrapped_fn.def_jvp
    def _jvp(primals, tangents):
        primals_out = wrapped_fn(*primals)
        return primals_out, tangents

    return wrapped_fn


new_inner_model_wrapped = wrapper(new_inner_model)


def loss(outer_params, outer_static):
    a = outer_params.outer_model.a
    b = outer_params.outer_model.b
    # Do something to the inner params
    model = eqx.combine(outer_params, outer_static)
    inner_param, inner_static = eqx.partition(model, inner_filter)
    inner_param = jax.tree.map(lambda x: x*(a + b**2), inner_param)
    model = eqx.combine(inner_param, inner_static)
    # model = new_inner_model(model)  # This works
    model = new_inner_model_wrapped(model)  # This doesn't work
    inner_param, inner_static = eqx.partition(inner_param, inner_filter)
    flat_inner_params = jax.flatten_util.ravel_pytree(inner_param)[0]
    return a + b + jnp.sum(flat_inner_params) + a + b


# test
import jax.tree as jt

outer_param_dict = {'a': jnp.array(1.0), 'b': jnp.array(2.0)}
container = Container(outer_param_dict)
base_filter = jt.map(lambda _: False, container)


def _bilevel_igor_filters(model):
    task_filter = jt.map(eqx.is_array, model.inner_model)
    task_filter = eqx.tree_at(
        lambda tree: tree.inner_model, base_filter, task_filter)
    meta_filter = jt.map(eqx.is_array_like, model.outer_model)
    meta_filter = eqx.tree_at(
        lambda tree: tree.outer_model, base_filter, meta_filter)
    return task_filter, meta_filter


inner_filter, outer_filter = _bilevel_igor_filters(container)
outer_params, outer_static = eqx.partition(container, outer_filter)
eqx.filter_grad(loss)(outer_params, outer_static).outer_model
Do you know what needs to be done to make this work?
The output of an eqx.filter_custom_jvp must still consist only of JAX types. In this case you're returning the output of new_inner_model, which still has its static components.
(I think)
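If that is the issue, one possible workaround is sketched below (untested, and the names apply_truncated and the use of eqx.is_array are assumptions for illustration): keep the static part of the model outside the custom JVP, so that the wrapped function only ever receives and returns arrays.

# Untested sketch: the custom-JVP boundary only sees the partitioned array pytree;
# the static part of the model stays outside and is recombined by the caller.
def wrapper(fn):
    """Wrap fn to do truncated backprop (identity JVP)."""

    @eqx.filter_custom_jvp
    def wrapped_fn(params, static):
        new_model = fn(eqx.combine(params, static))
        # Return only the array leaves; the caller recombines with the static part.
        new_params, _ = eqx.partition(new_model, eqx.is_array)
        return new_params

    @wrapped_fn.def_jvp
    def _jvp(primals, tangents):
        primals_out = wrapped_fn(*primals)
        params_tangent, _ = tangents
        # Pass the incoming parameter tangents straight through (truncation).
        return primals_out, params_tangent

    def apply_truncated(model):
        params, static = eqx.partition(model, eqx.is_array)
        new_params = wrapped_fn(params, static)
        return eqx.combine(new_params, static)

    return apply_truncated

The point of the sketch is just that wrapped_fn's output is the partitioned parameter pytree rather than the full module, so no static components cross the custom-JVP boundary.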