equinox icon indicating copy to clipboard operation
equinox copied to clipboard

Recover optimised model parameters as a dictionary?

Open adam-hartshorne opened this issue 1 year ago • 4 comments

As a minimal example, let's say we have something like this (the real situation involves custom modules which contain other modules, hence the flatten). If I look at flat_model, it obviously holds the optimised values of Z and U, but they aren't labelled, whereas treedef_model holds the labelled structure.

I presume I am missing something really simple here, but I want to construct a dictionary of the resultant optimised params so that I can feed them into another model.

import equinox as eqx
import jax.numpy as jnp
import jax.tree as jtu
import optax


class Test_Model(eqx.Module):
    """Minimal Equinox module holding two array parameters, ``Z`` and ``U``.

    The parameters are supplied as a dictionary keyed by ``'Z'`` and ``'U'``.
    """

    Z: jnp.ndarray
    U: jnp.ndarray

    def __init__(self,
                 params):
        """Store the ``'U'`` and ``'Z'`` entries of ``params`` as fields.

        Args:
            params: Mapping with the keys ``'U'`` and ``'Z'``.

        Raises:
            KeyError: If either key is missing from ``params``.
        """
        self.U = params['U']
        self.Z = params['Z']

    def __call__(self, *args, **kwargs):
        """Apply the model (body elided in this minimal example)."""
        ...  # do stuff


# Initial parameter values, keyed by the field names of Test_Model.
params = {'Z': jnp.zeros((10, 2)), 'U': jnp.zeros((10, 1))}


model = Test_Model(params)
# optax names the step-size argument `learning_rate`; `LR` is not accepted.
opt_init, opt_update = optax.adabelief(learning_rate=0.01)
# Only the array leaves of the model are handed to the optimiser.
opt_state = opt_init(eqx.filter(model, eqx.is_array))
# `jtu` is `jax.tree`, whose functions carry no `tree_` prefix.
# Each flatten yields (list of leaves, treedef describing the structure).
flat_model, treedef_model = jtu.flatten(model)
flat_opt_state, treedef_opt_state = jtu.flatten(opt_state)


@eqx.filter_value_and_grad
def compute_loss(model, data):
    """Return the loss of ``model`` on ``data`` (body elided in this example).

    Because of the ``eqx.filter_value_and_grad`` decorator, calling this
    function yields a ``(loss, grads)`` pair, which is how ``make_step``
    consumes it.
    """
    ...  # some loss

@eqx.filter_jit
def make_step(flat_model, data, flat_opt_state):
    """Run one optimisation step on flattened model and optimiser state.

    The flat leaf lists are rebuilt into full pytrees using the module-level
    ``treedef_model`` and ``treedef_opt_state``, one gradient update is
    applied, and the results are flattened again before being returned.

    Args:
        flat_model: Leaves of the model, in ``treedef_model`` order.
        data: Input passed through to ``compute_loss``.
        flat_opt_state: Leaves of the optimiser state, in
            ``treedef_opt_state`` order.

    Returns:
        A ``(loss, flat_model, flat_opt_state)`` tuple holding the loss and
        the updated leaf lists.
    """
    # `jtu` is `jax.tree`: its functions are `unflatten` / `leaves`, without
    # the `tree_` prefix used by `jax.tree_util`.
    model = jtu.unflatten(treedef_model, flat_model)
    opt_state = jtu.unflatten(treedef_opt_state, flat_opt_state)
    loss, grads = compute_loss(model, data)
    updates, opt_state = opt_update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    flat_model = jtu.leaves(model)
    flat_opt_state = jtu.leaves(opt_state)
    return loss, flat_model, flat_opt_state


# Run 100 optimisation steps, threading the flattened model and optimiser
# state through make_step on every iteration.
for step in range(100):
    loss, flat_model, flat_opt_state = make_step(flat_model, data, flat_opt_state)


# Open question: recover the values of the model parameters from the flat
# model, such that I have a dictionary in the form of the original "params".
opt_params = ????
# i.e. such that opt_params = {'Z': ..., 'U': ...}

adam-hartshorne avatar Jan 31 '23 17:01 adam-hartshorne