dm-haiku
dm-haiku copied to clipboard
How to use `multi_transform_with_state` with `is_training`?
I am currently using this script to setup my model:
import haiku as hk
import jax
import numpy as np
import jax.numpy as jnp
key = jax.random.PRNGKey(1)
class MyModule(hk.Module):
def f(self, x, is_training):
x = hk.Linear(2)(x)
x = hk.BatchNorm(True, True, 0.9)(x, is_training)
x = jax.nn.relu(x)
return x
def g(self, x, is_training):
x = hk.Linear(2)(x)
x = hk.BatchNorm(True, True, 0.9)(x, is_training)
x = jax.nn.relu(x)
return x
def multi_transform_target():
module = MyModule()
def module_walk(inputs):
x = module.f(inputs, is_training=True)
x = module.f(inputs, is_training=False)
y = module.g(x, is_training=True)
y = module.g(x, is_training=False)
return (x, y)
return module_walk, (module.f, module.g)
hk_transformed = hk.multi_transform_with_state(multi_transform_target)
params, state = hk_transformed.init(key, jnp.ones(2))
Is this the right way to do it? or I should pass is_training
to module_walk
instead?