dm-haiku
hk.switch does not work inside a hk.vmap function when hk.set_state is used
Hi,
I have been using Haiku (amazing tool, btw!) for about 8 months now. Until now, I wrapped my custom Haiku modules with hk.transform. Inside my module, I vmapped a function (using hk.vmap) that contains an hk.switch call (to evaluate a chosen branch function).
I recently moved to stateful modules, which require hk.transform_with_state. I also need to keep track of a specific value over time in my machine learning model that must not be updated by the optimizer. For this, I use hk.set_state("name", val) to store it and access it later. However, as soon as I add a set_state call anywhere in the model, the vmapped function fails with the error:
ValueError: vmap has mapped output but out_axes is None
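For context, plain jax.vmap raises a ValueError of this kind whenever an output depends on a mapped input but its out_axes entry is None. The sketch below only illustrates that general condition with a hypothetical `double` function; it does not reproduce the Haiku behaviour, and whether hk.vmap ends up assigning out_axes=None to the state internally is an assumption I have not verified.

```python
import jax
import jax.numpy as jnp


def double(x):
    # Hypothetical function whose output depends on the mapped input.
    return x * 2


xs = jnp.arange(3.0)

# Default out_axes=0: per-example results are stacked along axis 0.
mapped = jax.vmap(double)(xs)

# out_axes=None declares the output to be identical for every example,
# but double(x) varies with the mapped input, so JAX rejects it.
try:
    jax.vmap(double, out_axes=None)(xs)
    raised = False
except ValueError as err:
    raised = True
    message = str(err)
```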
Is there any way to use hk.switch inside a hk.vmap function when hk.set_state is used in the module?
Thank you.
Code
import haiku as hk
import numpy as np
import jax
import jax.numpy as jnp


def forward(input_):
    # hk.set_state("some_state_val", 2)  # line that causes the problem

    def func(i, x):
        temp = hk.switch(i.squeeze(), applys, x)
        return temp

    func_vmap = hk.vmap(func, in_axes=(0, None), split_rng=True)

    applys = []
    for i in range(10):
        temp_apply = lambda x: x**2
        applys.append(temp_apply)

    pred = func_vmap(input_[1], input_[0])
    return pred


rng_key = jax.random.PRNGKey(4)
stateful_forward = hk.transform_with_state(forward)

data = jnp.asarray(np.random.rand(100, 30))
idx = jnp.asarray(np.random.randint(0, 10, (100, 1))).astype(jnp.int32)
inp = [data, idx]

init, apply = stateful_forward.init, stateful_forward.apply
init = jax.jit(init)
apply = jax.jit(apply)

params, state = init(rng_key, inp)
output, state = apply(params, state, rng_key, inp)
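For comparison, here is a minimal sketch of the same index-then-branch pattern in plain JAX, using jax.lax.switch under jax.vmap. The `branches` list, `select`, and the small inputs are hypothetical stand-ins for `applys`, `func`, and the data above. This only applies when the branch functions are pure (they create no Haiku parameters and touch no Haiku state), and I have not verified that it avoids the error inside hk.transform_with_state.

```python
import jax
import jax.numpy as jnp

# Hypothetical pure branches; each returns an array of the same shape and dtype.
branches = [lambda x, p=p: x ** p for p in (1, 2, 3)]


def select(i, x):
    # i has shape (1,), like one row of idx above; squeeze it to a scalar index.
    return jax.lax.switch(i.squeeze(), branches, x)


# Map over the index only; x is shared across all examples.
select_vmap = jax.vmap(select, in_axes=(0, None))

x = jnp.array([2.0, 3.0])
idx = jnp.array([[0], [1], [2]], dtype=jnp.int32)
out = select_vmap(idx, x)  # one row per index, shape (3, 2)
```

Note that with a batched index, vmap evaluates every branch and selects among the results, so the branches should be cheap and free of side effects.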