
Vmap over list of haiku functions

Open wcarvalho opened this issue 3 years ago • 8 comments

Discussed in https://github.com/google/jax/discussions/9413

Originally posted by wcarvalho, February 2, 2022:

Hello, first of all, thanks for the great tool.

I have a question about doing vmap over a list of Haiku functions. This could be useful, for example, as an easy way to do multi-head attention (though I'm interested in other kinds of parallel computation as well).

Following this, I tried using hk.switch but got an error I don't understand.

Here is minimal code:

import jax
import jax.numpy as jnp
import haiku as hk

# create network + initialize parameters
def linear(x):

    functions = [hk.Linear(64) for i in range(8)]
    index = jnp.arange(len(functions))

    vmap_functions = jax.vmap(lambda i, x: hk.switch(i, functions, x))
    x = vmap_functions(index, x)
    
    return x

x = jnp.zeros((8, 10, 128))
net = hk.without_apply_rng(hk.transform(linear))
params = net.init(jax.random.PRNGKey(42), x)

y = net.apply(params, x)

and this is the error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Input In [3], in <module>
     12 x = jnp.zeros((8, 10, 128))
     13 net = hk.without_apply_rng(hk.transform(linear))
---> 14 params = net.init(jax.random.PRNGKey(42), x)
     16 y = net.apply(params, x)
     18 print(y.shape, jax.tree_map(lambda x: x.shape, params))

File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/transform.py:113, in without_state.<locals>.init_fn(*args, **kwargs)
    112 def init_fn(*args, **kwargs):
--> 113   params, state = f.init(*args, **kwargs)
    114   if state:
    115     raise ValueError("If your transformed function uses `hk.{get,set}_state` "
    116                      "then use `hk.transform_with_state`.")

File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/transform.py:364, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
    362 with base.new_context(rng=rng) as ctx:
    363   try:
--> 364     f(*args, **kwargs)
    365   except jax.errors.UnexpectedTracerError as e:
    366     raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e

Input In [3], in linear(x)
      5 index = jnp.arange(len(functions))
      7 vmap_functions = jax.vmap(lambda i, x: hk.switch(i, functions, x))
----> 8 x = vmap_functions(index, x)
     10 return x

    [... skipping hidden 3 frame]

Input In [3], in linear.<locals>.<lambda>(i, x)
      4 functions = [hk.Linear(64) for i in range(8)]
      5 index = jnp.arange(len(functions))
----> 7 vmap_functions = jax.vmap(lambda i, x: hk.switch(i, functions, x))
      8 x = vmap_functions(index, x)
     10 return x

File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/stateful.py:447, in switch(index, branches, operand)
    445 stateful_branch_mem = _memoize_by_id(stateful_branch)
    446 state = internal_state()
--> 447 out, state = jax.lax.switch(
    448     index, tuple(map(stateful_branch_mem, branches)), (state, operand))
    449 update_internal_state(state)
    450 return out

    [... skipping hidden 2 frame]

File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py:2192, in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
   2185 """Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
   2186 
   2187 Corresponding `tree` and `avals` must match in the sense that the number of
   2188 leaves in `tree` must be equal to the length of `avals`. `what` will be
   2189 prepended to details of the mismatch in TypeError.
   2190 """
   2191 if tree1 != tree2:
-> 2192   raise TypeError(
   2193       f"{what} must have same type structure, got {tree1} and {tree2}.")
   2194 if not all(_map(core.typematch, avals1, avals2)):
   2195   diff = tree_multimap(_show_diff, tree_unflatten(tree1, avals1),
   2196                        tree_unflatten(tree2, avals2))

TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef((*, CustomNode(namedtuple[<class 'haiku._src.stateful.InternalState'>], [CustomNode(<class 'collections.defaultdict'>[(<class 'dict'>, ('linear',))], [{'b': *, 'w': *}]), CustomNode(<class 'collections.defaultdict'>[(<class 'dict'>, ())], []), (*, (*,))]))) and PyTreeDef((*, CustomNode(namedtuple[<class 'haiku._src.stateful.InternalState'>], [CustomNode(<class 'collections.defaultdict'>[(<class 'dict'>, ('linear_1',))], [{'b': *, 'w': *}]), CustomNode(<class 'collections.defaultdict'>[(<class 'dict'>, ())], []), (*, (*,))]))).

wcarvalho commented Feb 02 '22 16:02

Here is a solution I found using hk.lift. I am very new to this whole vmap/Haiku thing, so I would really appreciate feedback on whether this is kosher or will produce a bug. For example, should I be calling hk.running_init anywhere?

import jax
import jax.numpy as jnp
import haiku as hk
def linear(x):

    # create 8 linear functions to apply in parallel
    functions = [hk.Linear(64) for i in range(8)]
    # nested transforms for all 8 + extract inits + applies
    init_applies = jax.tree_map(lambda net: hk.transform(net), functions)  # nested transform
    inits = [i[0] for i in init_applies]
    applies = [i[1] for i in init_applies]

    # get lifted parameters for each linear
    params = [hk.lift(init, name="inner")(hk.next_rng_key(), x) for init in inits]

    # take first set of parameters and change to mutable dict.
    # will be used for parallel processing along with first apply fn
    storage = hk.data_structures.to_mutable_dict(params[0])

    # stack all the parameters so they're in one dict.
    # need to do this because each lifted Linear gets its own module name
    # ("linear", "linear_1", etc.), so the param dicts have different keys:
    # > print(jax.tree_map(lambda x: x.shape, params)):
    # >   FlatMap({'linear': FlatMap({'b': (64,), 'w': (128, 64)})}), FlatMap({'linear_1': FlatMap({'b': (64,), 'w': (128, 64)})}), ...
    just_params = [next(iter(p.values())) for p in params]

    # stack params and place in first params dict
    stacked_params = jax.tree_map(lambda *arrays: jnp.stack(arrays), *just_params)
    storage['linear'] = stacked_params

    
    def apply_fn(params, x):
        return applies[0](params, hk.next_rng_key(), x)
    
    # parallel over the 0th dim of params, the 1st dim of data
    x = jax.vmap(apply_fn, in_axes=(0, 1))(storage, x)
    
    return x

x = jnp.zeros((10, 8, 128))
init, apply = hk.transform(linear)
params = init(jax.random.PRNGKey(42), x)

y = apply(params, jax.random.PRNGKey(42), x)


print(y.shape)
print(jax.tree_map(lambda x: x.shape, params))

Output

(8, 10, 64)
FlatMap({
  'inner/linear': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_1/linear_1': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_2/linear_2': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_3/linear_3': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_4/linear_4': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_5/linear_5': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_6/linear_6': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_7/linear_7': FlatMap({'b': (64,), 'w': (128, 64)}),
})

wcarvalho commented Feb 02 '22 17:02

Hi, in this case you can avoid using hk.lift if you would prefer.

We're missing documentation here (I will fix this shortly; an internal user pointed it out earlier today :smile:), but hk.switch requires the parameters in all branches to be created before you call it. The easiest way to do this is to unconditionally call all branches during init, and use the switch during apply:

if hk.running_init():
  # Create all parameters in all branches.
  example = x
  for branch in branches:
    x = branch(example)
else:
  x = hk.switch(index, branches, x)

I've reworked your code to do so and it now runs correctly in colab for me: https://colab.research.google.com/gist/tomhennigan/86fe7b7d46a930b58a5d8d76a75d75ea/fix-for-300.ipynb
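Roughly, that pattern applied to your original example looks like the sketch below (sketch only, not the exact notebook code; it uses hk.vmap with split_rng=False at apply time since none of the branches need an rng key):

import jax
import jax.numpy as jnp
import haiku as hk

def linear(x):
    # x: [num_branches, batch, features] = [8, 10, 128]
    functions = [hk.Linear(64) for _ in range(8)]
    index = jnp.arange(len(functions))

    if hk.running_init():
        # Unconditionally run every branch once so all parameters get created.
        example = x[0]
        return jnp.stack([f(example) for f in functions])

    # By apply time all parameters exist, so every branch of hk.switch
    # returns the same tree structure and the error above goes away.
    return hk.vmap(lambda i, xi: hk.switch(i, functions, xi),
                   split_rng=False)(index, x)

x = jnp.zeros((8, 10, 128))
net = hk.without_apply_rng(hk.transform(linear))
params = net.init(jax.random.PRNGKey(42), x)
y = net.apply(params, x)
print(y.shape)  # expected: (8, 10, 64)

During init all branches are called once outside the vmap, so each hk.Linear creates its own parameters; during apply the parameters already exist and hk.switch only has to select between them.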

tomhennigan commented Feb 02 '22 17:02

Thank you for the prompt response! This solved my problem!

wcarvalho commented Feb 02 '22 18:02

FYI to anyone who comes here: you need to add split_rng=(not hk.running_init()) when using hk.vmap.
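In code that is just something like the following fragment (fn, index and x stand in for whatever is being mapped):

# As noted above: only split the rng across the mapped axis when not initializing.
mapped = hk.vmap(fn, split_rng=(not hk.running_init()))
out = mapped(index, x)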

wcarvalho commented Feb 05 '22 23:02

Question: the following for loop seems to be slowing down compilation significantly:

functions = [hk.Linear(1) for i in range(8)]

Any thoughts on how to speed this up?

wcarvalho commented Feb 13 '22 23:02

Compile time for a program of that size should only be 2-3 seconds, and this is a cost you should only pay once per process (unless you are frequently changing input shape).

Perhaps you've not added jax.jit around the call to apply? This should significantly speed up your program. Here is an updated example:

https://colab.research.google.com/gist/tomhennigan/31db456baad6d6a65dd4f53e7862aa6a/fix-for-300.ipynb
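Concretely, wrapping apply with jax.jit is all that is needed (a minimal fragment reusing net, params and x from the first snippet in this thread):

fast_apply = jax.jit(net.apply)
y = fast_apply(params, x)  # the first call traces and compiles; later calls with
                           # the same input shapes reuse the compiled program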

tomhennigan commented Feb 14 '22 08:02

Hello Tom, my apologies for the late response. I'm finding that this has a long compile time when used to create parameters for the DQN loss objective in the ACME library.

  1. where jit happens
  2. loss function details

Compilation can take up to an hour.

wcarvalho commented Mar 02 '22 23:03

I think I figured out the problem. Inside an RNN, this is SLOWER than just using a for loop, @tomhennigan.

wcarvalho commented Mar 30 '22 21:03