dm-haiku
Vmap over list of haiku functions
Discussed in https://github.com/google/jax/discussions/9413
Originally posted by wcarvalho, February 2, 2022

Hello! First, thanks for the great tool.
I have a question about doing vmap over a list of Haiku functions. This could be useful, for example, as an easy way to do multi-head attention (though I'm interested in other parallel computation as well).
Following this, I tried using `hk.switch` but got an error I don't understand. Here is a minimal example:
```python
import jax
import jax.numpy as jnp
import haiku as hk

# create network + initialize parameters
def linear(x):
  functions = [hk.Linear(64) for i in range(8)]
  index = jnp.arange(len(functions))
  vmap_functions = jax.vmap(lambda i, x: hk.switch(i, functions, x))
  x = vmap_functions(index, x)
  return x

x = jnp.zeros((8, 10, 128))
net = hk.without_apply_rng(hk.transform(linear))
params = net.init(jax.random.PRNGKey(42), x)
y = net.apply(params, x)
```
and this is the error:
```
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Input In [3], in <module>
12 x = jnp.zeros((8, 10, 128))
13 net = hk.without_apply_rng(hk.transform(linear))
---> 14 params = net.init(jax.random.PRNGKey(42), x)
16 y = net.apply(params, x)
18 print(y.shape, jax.tree_map(lambda x: x.shape, params))
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/transform.py:113, in without_state.<locals>.init_fn(*args, **kwargs)
112 def init_fn(*args, **kwargs):
--> 113 params, state = f.init(*args, **kwargs)
114 if state:
115 raise ValueError("If your transformed function uses `hk.{get,set}_state` "
116 "then use `hk.transform_with_state`.")
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/transform.py:364, in transform_with_state.<locals>.init_fn(rng, *args, **kwargs)
362 with base.new_context(rng=rng) as ctx:
363 try:
--> 364 f(*args, **kwargs)
365 except jax.errors.UnexpectedTracerError as e:
366 raise jax.errors.UnexpectedTracerError(unexpected_tracer_hint) from e
Input In [3], in linear(x)
5 index = jnp.arange(len(functions))
7 vmap_functions = jax.vmap(lambda i, x: hk.switch(i, functions, x))
----> 8 x = vmap_functions(index, x)
10 return x
[... skipping hidden 3 frame]
Input In [3], in linear.<locals>.<lambda>(i, x)
4 functions = [hk.Linear(64) for i in range(8)]
5 index = jnp.arange(len(functions))
----> 7 vmap_functions = jax.vmap(lambda i, x: hk.switch(i, functions, x))
8 x = vmap_functions(index, x)
10 return x
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/haiku/_src/stateful.py:447, in switch(index, branches, operand)
445 stateful_branch_mem = _memoize_by_id(stateful_branch)
446 state = internal_state()
--> 447 out, state = jax.lax.switch(
448 index, tuple(map(stateful_branch_mem, branches)), (state, operand))
449 update_internal_state(state)
450 return out
[... skipping hidden 2 frame]
File ~/miniconda3/envs/acmejax/lib/python3.9/site-packages/jax/_src/lax/control_flow.py:2192, in _check_tree_and_avals(what, tree1, avals1, tree2, avals2)
2185 """Raises TypeError if (tree1, avals1) does not match (tree2, avals2).
2186
2187 Corresponding `tree` and `avals` must match in the sense that the number of
2188 leaves in `tree` must be equal to the length of `avals`. `what` will be
2189 prepended to details of the mismatch in TypeError.
2190 """
2191 if tree1 != tree2:
-> 2192 raise TypeError(
2193 f"{what} must have same type structure, got {tree1} and {tree2}.")
2194 if not all(_map(core.typematch, avals1, avals2)):
2195 diff = tree_multimap(_show_diff, tree_unflatten(tree1, avals1),
2196 tree_unflatten(tree2, avals2))
TypeError: branch 0 and 1 outputs must have same type structure, got PyTreeDef((*, CustomNode(namedtuple[<class 'haiku._src.stateful.InternalState'>], [CustomNode(<class 'collections.defaultdict'>[(<class 'dict'>, ('linear',))], [{'b': *, 'w': *}]), CustomNode(<class 'collections.defaultdict'>[(<class 'dict'>, ())], []), (*, (*,))]))) and PyTreeDef((*, CustomNode(namedtuple[<class 'haiku._src.stateful.InternalState'>], [CustomNode(<class 'collections.defaultdict'>[(<class 'dict'>, ('linear_1',))], [{'b': *, 'w': *}]), CustomNode(<class 'collections.defaultdict'>[(<class 'dict'>, ())], []), (*, (*,))]))).
```
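The last line of the traceback is the key: branch 0 carries its parameters under the module name `linear` while branch 1 uses `linear_1`, so the two branches return internal state with different pytree structures to `jax.lax.switch`. A minimal sketch (not from the original thread) showing how each `hk.Linear` instance gets a distinct auto-numbered name:

```python
import jax
import jax.numpy as jnp
import haiku as hk

def f(x):
  # Two separate hk.Linear instances are auto-named "linear" and "linear_1".
  return [hk.Linear(64)(x) for _ in range(2)]

params = hk.transform(f).init(jax.random.PRNGKey(0), jnp.zeros((1, 128)))
print(list(params))  # ['linear', 'linear_1']
```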
Here is a solution I found using `hk.lift`. I am very new to this whole vmap/Haiku thing, so I would really appreciate feedback on whether this is kosher or will produce a bug. For example, I don't know whether I should be calling `hk.running_init` anywhere?
```python
import jax
import jax.numpy as jnp
import haiku as hk

def linear(x):
  # create 8 linear functions to apply in parallel
  functions = [hk.Linear(64) for i in range(8)]

  # nested transforms for all 8, then extract the init and apply fns
  init_applies = jax.tree_map(lambda net: hk.transform(net), functions)
  inits = [i[0] for i in init_applies]
  applies = [i[1] for i in init_applies]

  # get lifted parameters for each linear
  params = [hk.lift(init, name="inner")(hk.next_rng_key(), x) for init in inits]

  # take the first set of parameters and convert it to a mutable dict;
  # it will hold the stacked parameters used with the first apply fn
  storage = hk.data_structures.to_mutable_dict(params[0])

  # stack all the parameters so they live in one dict. This is needed
  # because each set sits under a different module name
  # ("linear", "linear_1", etc.):
  # > print(jax.tree_map(lambda x: x.shape, params))
  # > FlatMap({'linear': FlatMap({'b': (64,), 'w': (128, 64)})}), FlatMap({'linear_1': FlatMap({'b': (64,), 'w': (128, 64)})}), ...
  just_params = [next(iter(p.values())) for p in params]

  # stack params and place them in the first params dict
  stacked_params = jax.tree_map(lambda *arrays: jnp.stack(arrays), *just_params)
  storage['linear'] = stacked_params

  def apply_fn(params, x):
    return applies[0](params, hk.next_rng_key(), x)

  # parallelize over the 0th dim of the params and the 1st dim of the data
  x = jax.vmap(apply_fn, in_axes=(0, 1))(storage, x)
  return x

x = jnp.zeros((10, 8, 128))
init, apply = hk.transform(linear)
params = init(jax.random.PRNGKey(42), x)
y = apply(params, jax.random.PRNGKey(42), x)
print(y.shape)
print(jax.tree_map(lambda x: x.shape, params))
```
Output:

```
(8, 10, 64)
FlatMap({
  'inner/linear': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_1/linear_1': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_2/linear_2': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_3/linear_3': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_4/linear_4': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_5/linear_5': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_6/linear_6': FlatMap({'b': (64,), 'w': (128, 64)}),
  'inner_7/linear_7': FlatMap({'b': (64,), 'w': (128, 64)}),
})
```
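For intuition, the stacking step turns 8 separate parameter trees into a single tree whose leaves gain a leading axis of size 8, which is exactly what `in_axes=(0, 1)` then maps over. A standalone sketch with hypothetical stand-in parameters:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the 8 per-branch parameter trees.
params = [{'w': jnp.zeros((128, 64)), 'b': jnp.zeros((64,))} for _ in range(8)]

# Stack leaf-wise: each leaf gains a leading axis of size 8.
stacked = jax.tree_map(lambda *arrays: jnp.stack(arrays), *params)
print(jax.tree_map(lambda a: a.shape, stacked))
# {'b': (8, 64), 'w': (8, 128, 64)}
```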
Hi, in this case you can avoid using `hk.lift` if you would prefer.

We're missing documentation here (I will fix shortly; an internal user pointed this out earlier today :smile:), but `hk.switch` requires you to create parameters in all branches before calling it. The easiest way to do this is to unconditionally call all branches during init, and use the switch during apply:
```python
if hk.running_init():
  # Create all parameters in all branches.
  example = x
  for branch in branches:
    x = branch(example)
else:
  x = hk.switch(index, branches, x)
```
I've reworked your code to do so and it now runs correctly in Colab for me: https://colab.research.google.com/gist/tomhennigan/86fe7b7d46a930b58a5d8d76a75d75ea/fix-for-300.ipynb
Thank you for the prompt response! This solved my problem!

FYI to anyone who comes here: you need to add `split_rng=(not hk.running_init())` when using `hk.vmap`.
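Putting the two tips together, a minimal sketch of what the reworked version might look like. This is an approximation, not the author's exact code (the Colab gist above is the working version), and it assumes that creating parameters under `hk.vmap` at init time behaves as in that gist:

```python
import jax
import jax.numpy as jnp
import haiku as hk

def linear(x):
  functions = [hk.Linear(64) for _ in range(8)]
  index = jnp.arange(len(functions))

  def select_and_apply(i, xi):
    if hk.running_init():
      # Unconditionally call every branch so all parameters get created;
      # every branch has the same output shape, so returning the last is fine.
      for f in functions:
        out = f(xi)
      return out
    # At apply time, all parameters already exist, so hk.switch is safe.
    return hk.switch(i, functions, xi)

  # hk.vmap (not jax.vmap) so Haiku threads parameters through correctly;
  # split_rng must be False at init time, per the note above.
  return hk.vmap(select_and_apply, split_rng=(not hk.running_init()))(index, x)

x = jnp.zeros((8, 10, 128))
net = hk.without_apply_rng(hk.transform(linear))
params = net.init(jax.random.PRNGKey(42), x)
y = net.apply(params, x)  # expected shape: (8, 10, 64)
```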
Question: the following for loop seems to be slowing down compilation significantly.

```python
functions = [hk.Linear(1) for i in range(8)]
```

Any thoughts on how to speed this up?
Compile time for a program of that size should only be 2-3 seconds, and this is a cost you should only pay once per process (unless you are frequently changing input shape).
Perhaps you've not added `jax.jit` around the call to apply? This should significantly speed up your program. Here is an updated example:
https://colab.research.google.com/gist/tomhennigan/31db456baad6d6a65dd4f53e7862aa6a/fix-for-300.ipynb
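For reference, wrapping apply in `jax.jit` looks roughly like this (a sketch on a toy network; the notebook above is the actual updated example):

```python
import jax
import jax.numpy as jnp
import haiku as hk

def linear(x):
  return hk.Linear(64)(x)

x = jnp.zeros((10, 128))
net = hk.without_apply_rng(hk.transform(linear))
params = net.init(jax.random.PRNGKey(42), x)

# Compile apply once; later calls with the same input shapes reuse the
# cached executable instead of re-tracing the program.
jitted_apply = jax.jit(net.apply)

y = jitted_apply(params, x)  # first call traces and compiles
y = jitted_apply(params, x)  # subsequent calls run the cached executable
```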
Hello Tom, my apologies for the late response. I'm finding that this has a long compile time when used to create parameters for the DQN loss objective in the ACME library. Compilation will take up to an hour.

I think I figured out the problem: inside an RNN, this is SLOWER than just using a for loop, @tomhennigan.