How to nest modules that have their own train() functions?
I have a fairly complex TF2 code base that I'm trying to convert to JAX / Haiku. There are many nested modules, some of which contain custom logic for updating their own parameters. Some of the modules are also used for computing the loss of other modules.
I figured out a way to give the modules their own train() functions, illustrated in the full example below. However, it's quite complicated: it requires nested hk.transforms inside each trainable module and does a lot of partitioning and merging of the global params dict (I couldn't figure out how to update the values of module.get_params() from within a module).
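To make it concrete, by "partitioning and merging" I just mean this kind of pattern (a minimal standalone sketch with a throwaway two-layer net; the module names 'first' / 'second' and the dummy loss are placeholders, not taken from my actual code):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def net(x):
  x = jax.nn.relu(hk.Linear(16, name='first')(x))
  return hk.Linear(1, name='second')(x)

transformed = hk.without_apply_rng(hk.transform(net))
x = jnp.ones((8, 4))
params = transformed.init(jax.random.PRNGKey(0), x)

# Split the global params dict into the subset to update and everything else.
trainable, frozen = hk.data_structures.partition(
    lambda module_name, name, value: module_name.startswith('second'), params)

def loss(trainable, frozen, x):
  # Recombine before applying, so the network sees the full params dict.
  full = hk.data_structures.merge(trainable, frozen)
  return (transformed.apply(full, x) ** 2).mean()

grads = jax.grad(loss)(trainable, frozen, x)
trainable = jax.tree_map(lambda w, g: w - 0.1 * g, trainable, grads)
params = hk.data_structures.merge(trainable, frozen)  # back to one global dict
```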
Is there a better way of doing this?
```python
import haiku as hk
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
sg = jax.lax.stop_gradient


class Inner(hk.Module):

  def __init__(self, data_size, code_size):
    super().__init__()
    self.data_size = data_size
    self.code_size = code_size
    mean, std = jnp.zeros(code_size), jnp.ones(code_size)
    self.prior = tfd.Independent(tfd.Normal(mean, std), 1)
    self.tencode = hk.transform(self.encode)
    self.tdecode = hk.transform(self.decode)

  def build(self, x):
    self.decode(self.encode(x).mode())

  def train(self, params, x):
    # Update only the parameters that live under this module; keep the rest frozen.
    current = hk.experimental.current_name()
    params, frozen = hk.data_structures.partition(
        lambda m, n, p: m.startswith(current), params)
    grad = jax.grad(self.loss)(params, frozen, x)
    params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grad)
    return hk.data_structures.merge(params, frozen)

  def loss(self, params, frozen, x):
    params = hk.data_structures.merge(params, frozen)
    post = self.tencode.apply(params, None, x)
    recon = self.tdecode.apply(
        params, None, post.sample(seed=hk.next_rng_key()))
    mse = ((recon - x) ** 2).sum(-1).mean()
    kl = post.kl_divergence(self.prior).mean()
    return mse + 0.1 * kl

  def encode(self, x):
    assert x.shape[-1] == self.data_size, (x.shape, self.data_size)
    x = jax.nn.relu(hk.Linear(128)(x))
    x = jax.nn.relu(hk.Linear(128)(x))
    mean = hk.Linear(self.code_size)(x)
    std = jax.nn.softplus(hk.Linear(self.code_size)(x))
    return tfd.Independent(tfd.Normal(mean, std), 1)

  def decode(self, x):
    x = jax.nn.relu(hk.Linear(128)(x))
    x = jax.nn.relu(hk.Linear(128)(x))
    return hk.Linear(self.data_size)(x)


class Outer(hk.Module):

  def __init__(self, data_size, code_size):
    super().__init__()
    self.inner = Inner(data_size, code_size)
    self.tpred = hk.transform(self.pred)

  def build(self, x1, x2):
    self.inner.build(x1)
    self.pred(x2)

  def train(self, params, x1, x2):
    # Let the inner module update its own parameters first, then update ours.
    params = self.inner.train(params, x1)
    current = hk.experimental.current_name()
    params, frozen = hk.data_structures.partition(
        lambda m, n, p: m.startswith(current), params)
    grad = jax.grad(self.loss)(params, frozen, x1, x2)
    params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grad)
    params = hk.data_structures.merge(params, frozen)
    return params

  def loss(self, params, frozen, x1, x2):
    params = hk.data_structures.merge(params, frozen)
    target = sg(self.inner.tencode.apply(params, None, x1).mode())
    return ((self.tpred.apply(params, None, x2) - target) ** 2).mean()

  def pred(self, x):
    x = jax.nn.relu(hk.Linear(128)(x))
    x = hk.Linear(self.inner.code_size)(x)
    return x


def main():
  x1 = jnp.zeros((64, 128), jnp.float32)
  x2 = jnp.zeros((64, 32), jnp.float32)

  def make():
    outer = Outer(128, 8)
    return outer.build, (outer.inner.loss, outer.loss, outer.pred, outer.train)

  init, (loss1, loss2, pred, train) = hk.multi_transform(make)
  rng = jax.random.PRNGKey(42)
  params = init(rng, x1, x2)
  # The transform-level params dict is empty ({}) for loss1/loss2/train because
  # those methods take the real params as an explicit argument; only pred reads
  # the transform-level params directly.
  print('Loss 1:', jax.jit(loss1)({}, rng, params, {}, x1))
  print('Loss 2:', jax.jit(loss2)({}, rng, params, {}, x1, x2))
  params = jax.jit(train)({}, rng, params, x1, x2)
  print('Pred:', jax.jit(pred)(params, rng, x2).mean())


if __name__ == '__main__':
  main()
```