dm-haiku icon indicating copy to clipboard operation
dm-haiku copied to clipboard

How to use `multi_transform_with_state` with `is_training`?

Open uduse opened this issue 2 years ago • 0 comments

I am currently using this script to set up my model:

import haiku as hk
import jax
import numpy as np
import jax.numpy as jnp
# PRNG key built from a fixed seed (1); it is passed to `init` further down.
key = jax.random.PRNGKey(1)

class MyModule(hk.Module):
    """Module exposing two independent Linear -> BatchNorm -> ReLU paths.

    ``f`` and ``g`` have the same structure but are kept as separate methods
    so that each one builds its own sub-modules.
    """

    def f(self, x, is_training):
        """Apply the ``f`` path to ``x``.

        Args:
            x: Input array.
            is_training: Forwarded to the batch-norm layer.

        Returns:
            The ReLU-activated output of the path.
        """
        projected = hk.Linear(2)(x)
        batch_norm = hk.BatchNorm(
            create_scale=True,
            create_offset=True,
            decay_rate=0.9,
        )
        normalised = batch_norm(projected, is_training)
        return jax.nn.relu(normalised)

    def g(self, x, is_training):
        """Apply the ``g`` path to ``x``.

        Args:
            x: Input array.
            is_training: Forwarded to the batch-norm layer.

        Returns:
            The ReLU-activated output of the path.
        """
        projected = hk.Linear(2)(x)
        batch_norm = hk.BatchNorm(
            create_scale=True,
            create_offset=True,
            decay_rate=0.9,
        )
        normalised = batch_norm(projected, is_training)
        return jax.nn.relu(normalised)


def multi_transform_target():
    """Build the init walker and the apply functions for the multi-transform.

    Returns:
        A ``(module_walk, (module.f, module.g))`` pair. ``module_walk`` is
        the function traced at init time; the tuple holds the functions made
        available for apply. ``f`` and ``g`` keep their ``is_training``
        argument, so train/eval mode is chosen per apply call rather than
        being baked in here.
    """
    module = MyModule()

    def module_walk(inputs):
        """Visit ``f`` then ``g`` once so init sees every sub-module."""
        # A single training-mode pass per method is sufficient for init.
        # The previous version called each method a second time with
        # is_training=False and overwrote the first result, which only
        # repeated the trace. Training mode is used because batch norm is
        # expected to create its moving-average state on that path.
        x = module.f(inputs, is_training=True)
        y = module.g(x, is_training=True)
        return (x, y)

    return module_walk, (module.f, module.g)


# Transform the factory; the result exposes the `init` function used below.
hk_transformed = hk.multi_transform_with_state(multi_transform_target)
# Initialise with the PRNG key and a dummy input of two ones; the result is
# unpacked as `params` and `state`.
params, state = hk_transformed.init(key, jnp.ones(2))

Is this the right way to do it, or should I pass `is_training` to `module_walk` instead?

uduse avatar Jul 13 '22 02:07 uduse