
How to nest modules that have their own train() functions?

danijar opened this issue on May 21, 2022 · 0 comments

I have a fairly complex TF2 code base that I'm trying to convert to JAX / Haiku. It has many nested modules, some of which contain custom logic for updating their own parameters, and some of which are also used to compute the losses of other modules.

I figured out a way to give the modules their own train() functions, illustrated in the example below. However, it's quite complicated because it requires an hk.transform inside each trainable module and does a lot of partitioning and merging of the global params dict (I couldn't figure out how to update the values of module.get_params() within a module).

Is there a better way of doing this?

import haiku as hk
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
sg = jax.lax.stop_gradient


class Inner(hk.Module):

  def __init__(self, data_size, code_size):
    super().__init__()
    self.data_size = data_size
    self.code_size = code_size
    mean, std = jnp.zeros(code_size), jnp.ones(code_size)
    self.prior = tfd.Independent(tfd.Normal(mean, std), 1)
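    # Wrap encode/decode in their own hk.transform so that train()/loss()
    # can apply them with an explicit params dict.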
    self.tencode = hk.transform(self.encode)
    self.tdecode = hk.transform(self.decode)

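  # Runs every submodule once so that the enclosing init() creates all parameters.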
  def build(self, x):
    self.decode(self.encode(x).mode())

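  # One SGD step on this module's own parameters; everything else in the
  # global params dict stays frozen.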
  def train(self, params, x):
    current = hk.experimental.current_name()
    params, frozen = hk.data_structures.partition(
        lambda m, n, p: m.startswith(current), params)
    grad = jax.grad(self.loss)(params, frozen, x)
    params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grad)
    return hk.data_structures.merge(params, frozen)

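  # VAE-style loss: reconstruction error plus a KL penalty against the
  # unit-Gaussian prior.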
  def loss(self, params, frozen, x):
    params = hk.data_structures.merge(params, frozen)
    post = self.tencode.apply(params, None, x)
    recon = self.tdecode.apply(
        params, None, post.sample(seed=hk.next_rng_key()))
    mse = ((recon - x) ** 2).sum(-1).mean()
    kl = post.kl_divergence(self.prior).mean()
    return mse + 0.1 * kl

  def encode(self, x):
    assert x.shape[-1] == self.data_size, (x.shape, self.data_size)
    x = jax.nn.relu(hk.Linear(128)(x))
    x = jax.nn.relu(hk.Linear(128)(x))
    mean = hk.Linear(self.code_size)(x)
    std = jax.nn.softplus(hk.Linear(self.code_size)(x))
    return tfd.Independent(tfd.Normal(mean, std), 1)

  def decode(self, x):
    x = jax.nn.relu(hk.Linear(128)(x))
    x = jax.nn.relu(hk.Linear(128)(x))
    return hk.Linear(self.data_size)(x)


class Outer(hk.Module):

  def __init__(self, data_size, code_size):
    super().__init__()
    self.inner = Inner(data_size, code_size)
    self.tpred = hk.transform(self.pred)

  def build(self, x1, x2):
    self.inner.build(x1)
    self.pred(x2)

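  # Let the inner module update its own parameters first, then take a
  # gradient step on the outer parameters only.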
  def train(self, params, x1, x2):
    params = self.inner.train(params, x1)
    current = hk.experimental.current_name()
    params, frozen = hk.data_structures.partition(
        lambda m, n, p: m.startswith(current), params)
    grad = jax.grad(self.loss)(params, frozen, x1, x2)
    params = jax.tree_map(lambda w, g: w - 0.1 * g, params, grad)
    params = hk.data_structures.merge(params, frozen)
    return params

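  # Regress pred(x2) onto the (stop-gradient) mode of the inner encoder's posterior.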
  def loss(self, params, frozen, x1, x2):
    params = hk.data_structures.merge(params, frozen)
    target = sg(self.inner.tencode.apply(params, None, x1).mode())
    return ((self.tpred.apply(params, None, x2) - target) ** 2).mean()

  def pred(self, x):
    x = jax.nn.relu(hk.Linear(128)(x))
    x = hk.Linear(self.inner.code_size)(x)
    return x


def main():
  x1 = jnp.zeros((64, 128), jnp.float32)
  x2 = jnp.zeros((64, 32), jnp.float32)

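  # hk.multi_transform: build() is the init function; the other methods become
  # separate pure apply functions over the same parameter structure.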
  def make():
    outer = Outer(128, 8)
    return outer.build, (outer.inner.loss, outer.loss, outer.pred, outer.train)
  init, (loss1, loss2, pred, train) = hk.multi_transform(make)

  rng = jax.random.PRNGKey(42)
  params = init(rng, x1, x2)
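  # Each apply fn takes (transform_params, rng, *args); the loss/train calls
  # pass the real params as an explicit argument and {} for the
  # transform-level params and the frozen dict.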
  print('Loss 1:', jax.jit(loss1)({}, rng, params, {}, x1))
  print('Loss 2:', jax.jit(loss2)({}, rng, params, {}, x1, x2))
  params = jax.jit(train)({}, rng, params, x1, x2)
  print('Pred:', jax.jit(pred)(params, rng, x2).mean())


if __name__ == '__main__':
  main()
