Faithfully reconstruct tree from context

Open mattwescott opened this issue 5 years ago • 0 comments

This variation on @tomhennigan's example tries to build a tree of module types that mirrors the parameter tree.

It assumes that flattening the parameter dictionary yields parameters in the same order in which they were created, which may be incorrect. A more satisfying solution would key each recorded type by the parameter's path, which would work if the path were added to context or could be recovered from it. Since module names and parameter names may both contain "/", it is not clear to me how to construct the path. What am I missing?

import haiku
import tree  # dm-tree


def init_and_build_module_tree(f):
    """
    Wrap `f` so that init returns a tree of module types alongside the parameters.

    Usage:
      def f(x):
        net = haiku.nets.MLP([300, 100, 10])
        return net(x)

      params, modules = init_and_build_module_tree(f)(rng_key, np.zeros(4))
      params = tree.map_structure(transform_params, params, modules)
    """

    def _init_and_build_module_tree(rng_key, *args, **kwargs):
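        # Filled in creation order: one entry per parameter created.
        module_types = []

        def record_module_type(next_creator, shape, dtype, init, context):
            # Record the type of the module that owns this parameter.
            module_types.append(type(context.module))
            return next_creator(shape, dtype, init)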

        def with_creator(*aargs, **kkwargs):
            with haiku.experimental.custom_creator(record_module_type):
                return f(*aargs, **kkwargs)

        params, _ = haiku.transform_with_state(with_creator).init(
            rng_key,
            *args,
            **kwargs
        )

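        # Caveat: this pairs entries by position, so it assumes flattening
        # `params` visits parameters in the order they were created.
        module_tree = tree.unflatten_as(
            params,
            module_types
        )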

        return params, module_tree

    return _init_and_build_module_tree
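
To make the ordering concern concrete, here is a toy sketch in plain Python. It makes no Haiku or dm-tree calls. The creation log, the module and parameter names, and the helpers `unflatten_sorted` and `build_keyed` are all hypothetical, and it assumes that flattening a dict visits keys in sorted order. Under that assumption, pairing types by position goes wrong as soon as creation order and sorted order differ, while keying by a (module, parameter) tuple does not, and it also avoids splitting a joined path on "/".

```python
# Toy illustration only: plain Python stand-ins, no Haiku or dm-tree calls.

# Hypothetical creation log: (module name, parameter name, module type).
# "norm" is created before "linear", but sorts after it.
creation_log = [
    ("norm", "scale", "LayerNorm"),
    ("norm", "offset", "LayerNorm"),
    ("linear", "w", "Linear"),
    ("linear", "b", "Linear"),
]

# Parameter structure shaped like a two-level params dict.
params = {}
for module_name, param_name, _ in creation_log:
    params.setdefault(module_name, {})[param_name] = 0.0


def unflatten_sorted(structure, flat):
    """Fill `structure` from `flat`, visiting keys in sorted order (assumed)."""
    it = iter(flat)
    return {m: {p: next(it) for p in sorted(structure[m])}
            for m in sorted(structure)}


def build_keyed(structure, log):
    """Look up each leaf by its (module, parameter) key, not by position."""
    by_key = {(m, p): t for m, p, t in log}
    return {m: {p: by_key[(m, p)] for p in structure[m]} for m in structure}


# Positional pairing, as in the unflatten step above.
positional = unflatten_sorted(params, [t for _, _, t in creation_log])
# Keyed pairing, which needs the path at record time.
keyed = build_keyed(params, creation_log)

print(positional["linear"]["w"])
print(keyed["linear"]["w"])
```

With this log, the positional version labels the "linear" parameters with the type recorded for "norm", while the keyed version gets both right. The keyed version is only possible if the creator can see the module and parameter names separately, which is exactly the open question here.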

mattwescott, May 11 '20 14:05