dm-haiku
Faithfully reconstruct tree from context
This variation on @tomhennigan's example tries to build a tree of module types.
It assumes that flattening the parameter dictionary visits parameters in the same order in which they were created, which may not hold. Alternatively, if the path were added to the context, or could be recovered from it, that would allow a more satisfying solution. Since both module names and parameter names may contain "/", it is not clear to me how to construct the path. What am I missing?
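The ordering concern can be checked in isolation with a small pure-Python sketch. It assumes the flattener walks dictionary keys in sorted order (dm-tree does this for dicts), and it uses hypothetical module names with strings standing in for module types, so it runs without Haiku. Zipping types recorded in creation order against a sorted flattening can attach a type to the wrong parameter.

```python
# Hypothetical creation log: (module_name, param_name, module_type) in the order
# a custom creator would see them. Types are strings so no Haiku is needed.
creation_log = [
    ("mlp/~/linear_0", "w", "Linear"),
    ("mlp/~/linear_0", "b", "Linear"),
    ("layer_norm", "scale", "LayerNorm"),
    ("layer_norm", "offset", "LayerNorm"),
]


def flatten_sorted(nested):
    """Flatten a dict of dicts, visiting keys in sorted order (like dm-tree)."""
    return [(m, p) for m in sorted(nested) for p in sorted(nested[m])]


# Build a params-like nested dict from the log (leaf values are placeholders).
params = {}
for module_name, param_name, _ in creation_log:
    params.setdefault(module_name, {})[param_name] = None

creation_order = [(m, p) for m, p, _ in creation_log]
flat_order = flatten_sorted(params)

# Pairing by position, as the unflatten step does, misattributes types here.
types_in_creation_order = [t for _, _, t in creation_log]
misattributed = dict(zip(flat_order, types_in_creation_order))
print(misattributed[("layer_norm", "offset")])  # prints Linear, not LayerNorm
```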
import haiku
import tree  # dm-tree


def init_and_build_module_tree(f):
  """Decorated functions build a tree of module types alongside the parameters.

  Usage:
    def f(x):
      net = haiku.nets.MLP([300, 100, 10])
      return net(x)

    params, modules = init_and_build_module_tree(f)(rng_key, np.zeros(4))
    params = tree.map_structure(transform_params, params, modules)
  """
  def _init_and_build_module_tree(rng_key, *args, **kwargs):
    module_types = []

    def record_module_type(next_creator, shape, dtype, init, context):
      # Record the type of the module that owns each parameter, in creation order.
      module_types.append(type(context.module))
      return next_creator(shape, dtype, init)

    def with_creator(*args, **kwargs):
      with haiku.experimental.custom_creator(record_module_type):
        return f(*args, **kwargs)

    params, _ = haiku.transform_with_state(with_creator).init(
        rng_key, *args, **kwargs)

    # Assumes flattening `params` visits parameters in the order they were created.
    module_tree = tree.unflatten_as(params, module_types)
    return params, module_tree

  return _init_and_build_module_tree
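One way to avoid depending on flatten order, sketched under assumptions: if the creator records each parameter's `context.full_name` (assumed here to be the module name and the parameter name joined by "/"), the tree can be rebuilt by looking names up rather than by position. The module and parameter names are read back from the keys of `params` and only joined, never split, which sidesteps the question of where to split a name containing "/" (it still assumes no two parameters share a full name). `build_module_tree` and `recorded` are hypothetical names, and types are strings so the sketch runs without Haiku.

```python
def build_module_tree(params, recorded):
    """Mirror `params` (a dict of dicts), looking up each leaf's module type by name.

    `recorded` maps full parameter names to module types. In Haiku, a creator
    along these lines could fill it (assumption: the creator context exposes
    `full_name` and `module`):

        def record(next_creator, shape, dtype, init, context):
            recorded[context.full_name] = type(context.module)
            return next_creator(shape, dtype, init)
    """
    return {
        module_name: {
            # Join (never split) the two key levels to recover the full name.
            param_name: recorded[f"{module_name}/{param_name}"]
            for param_name in module_params
        }
        for module_name, module_params in params.items()
    }


params = {"mlp/~/linear_0": {"w": 0.0, "b": 0.0}}
recorded = {"mlp/~/linear_0/w": "Linear", "mlp/~/linear_0/b": "Linear"}
print(build_module_tree(params, recorded))
# {'mlp/~/linear_0': {'w': 'Linear', 'b': 'Linear'}}
```

Because the result is built from the same keys as `params`, it has the same structure, so it can be passed alongside `params` to a structure-mapping function regardless of how that function orders dictionary keys.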