equinox
equinox copied to clipboard
Recover optimised model parameters as a dictionary?
As a minimal example, let's say we have something like this (the real situation involves custom modules which contain other modules, hence the flatten). If I look at flat_model, it obviously holds the optimised values of Z and U, but they aren't labelled, whereas treedef_model holds the labelled structure.
I presume I am missing something really simple here, but I want to construct a dictionary of the resultant optimised params so that I can feed them into another model.
import equinox as eqx
import jax.numpy as jnp
import jax.tree as jtu
import optax
class Test_Model(eqx.Module):
    """Minimal Equinox module holding two array fields, ``Z`` and ``U``."""

    Z: jnp.ndarray
    U: jnp.ndarray

    def __init__(self, params):
        """Populate the fields from ``params``.

        Args:
            params: expected to be a mapping with ``'U'`` and ``'Z'`` entries;
                each entry is stored on the field of the same name.
        """
        self.U = params['U']
        self.Z = params['Z']

    def __call__(self, *args, **kwargs):
        """Forward pass; intentionally left as a placeholder in this example."""
        ...  # do stuff
# Initial parameter values, keyed by the field names of Test_Model.
params = {'Z': jnp.zeros((10, 2)), 'U': jnp.zeros((10, 1))}
model = Test_Model(params)

# optax names the step-size argument `learning_rate`; `LR` is not an accepted
# keyword and would raise a TypeError.
opt_init, opt_update = optax.adabelief(learning_rate=0.01)
opt_state = opt_init(eqx.filter(model, eqx.is_array))

# `jtu` is bound to `jax.tree` in this file, whose functions are spelled
# `flatten` / `unflatten` / `leaves` (the `tree_`-prefixed names belong to
# `jax.tree_util`).
# The treedefs are kept so the flat leaf lists can later be turned back into
# a labelled model / optimiser state.
flat_model, treedef_model = jtu.flatten(model)
flat_opt_state, treedef_opt_state = jtu.flatten(opt_state)
@eqx.filter_value_and_grad
def compute_loss(model, data):
    """Loss of ``model`` on ``data``; intentionally a placeholder in this example.

    Because of the decorator, callers receive a ``(loss, grads)`` pair rather
    than the bare loss value.
    """
    ...  # some loss
@eqx.filter_jit
def make_step(flat_model, data, flat_opt_state):
    """Run one optimisation step on flattened model / optimiser state.

    Args:
        flat_model: leaves of the model, in the order given by ``treedef_model``.
        data: passed through unchanged to ``compute_loss``.
        flat_opt_state: leaves of the optimiser state, in the order given by
            ``treedef_opt_state``.

    Returns:
        ``(loss, flat_model, flat_opt_state)`` with the updated leaves.
    """
    # `jtu` is `jax.tree` in this file, so the functions are `unflatten` and
    # `leaves`; the `tree_`-prefixed names only exist in `jax.tree_util`.
    model = jtu.unflatten(treedef_model, flat_model)
    opt_state = jtu.unflatten(treedef_opt_state, flat_opt_state)

    loss, grads = compute_loss(model, data)
    updates, opt_state = opt_update(grads, opt_state)
    model = eqx.apply_updates(model, updates)

    # Hand back plain leaf lists so the caller can feed them straight into
    # the next step.
    flat_model = jtu.leaves(model)
    flat_opt_state = jtu.leaves(opt_state)
    return loss, flat_model, flat_opt_state
# NOTE: `data` must be defined before this loop; it is not shown in this example.
for step in range(100):
    loss, flat_model, flat_opt_state = make_step(flat_model, data, flat_opt_state)

# Recover the values of the model parameters from the flat model: the treedef
# saved before training restores the labelled structure, after which each
# field can be read by name (nested modules can be read the same way, via
# nested attribute access).
trained_model = jtu.unflatten(treedef_model, flat_model)
# Same layout as the original `params` dictionary: {'Z': ..., 'U': ...}
opt_params = {'Z': trained_model.Z, 'U': trained_model.U}