dm-haiku icon indicating copy to clipboard operation
dm-haiku copied to clipboard

AttributeError: 'Regressor' object has no attribute '_auto_repr'

Open danijar opened this issue 3 years ago • 0 comments

I'm trying to get the following code to work but getting the mentioned error (full stack trace below). Putting hk.experimental.module_auto_repr(False) at the top of the file avoids this error (throwing a new unrelated error), so I'm assuming this is a bug in Haiku?

import haiku as hk
import jax
import jax.numpy as jnp
import optax


class Regressor(hk.Module):

  def __init__(self):
    super().__init__()
    self.mlp = hk.nets.MLP([128, 128, 8])
    self.opt = Optimizer(self.loss, self.mlp)

  def __call__(self, data):
    return self.mlp(data['x'])

  def build(self, data):
    self.mlp(data['x'])

  def train(self, params, data):
    params = self.opt(params, data)

  def loss(self, data):
    return ((self.mlp(data['x']) - data['y']) ** 2).mean()


class Distiller(hk.Module):

  def __init__(self):
    super().__init__()
    self.target = Regressor()
    self.mlp = hk.nets.MLP([128, 128, 8])
    self.opt = Optimizer(self.loss, self.mlp)

  def build(self, data):
    self.target.build(data)
    self.mlp(data['x'])

  def train(self, params, data):
    params = self.target.train(params, data)
    params = self.opt(params, data)
    return params

  def loss(self, data):
    target = jax.lax.stop_gradient(self.target(data))
    return ((self.mlp(data['x']) - target) ** 2).mean()


class Optimizer(hk.Module):

  def __init__(self, loss, module, lr=1e-4):
    super().__init__()
    self.tloss = hk.transform(loss)
    self.module = module
    self.opt = optax.adam(lr)
    # TODO: Does this work? Will it be included in checkpoints?
    self.state = self.opt.init(self.module.params_dict())

  def __call__(self, params, *a, **k):
    module_params = self.module.params_dict()
    print(dir(self.module))
    print('MODULE PARAMS:', module_params)
    trainable, frozen = hk.data_structures.partition(
        lambda n, m, p: n in module_params, params)
    print('TRAINABLE:', trainable)
    def inner(trainable, frozen, *a, **k):
      params = hk.data_structures.merge(frozen, trainable)
      return self.tloss.apply(params, None, *a, **k)
    grads = hk.grad(inner)(trainable, frozen, *a, **k)
    updates, self.state = self.opt.update(grads, self.state)
    trainable = optax.apply_updates(trainable, updates)
    params = hk.data_structures.merge(frozen, trainable)
    # TODO: Does this work?
    # self.module.params_dict().update(params)
    return params


def main():
  def f():
    module = Distiller()
    def init(data):
      module.build(data)
    return init, (module.train, module.loss, module.mlp)
  init, (train, loss, pred) = hk.multi_transform(f)
  train, loss, pred = jax.jit(train), jax.jit(loss), jax.jit(pred)
  data = {'x': jnp.zeros((16, 64)), 'y': jnp.zeros((16, 8))}
  rng = jax.random.PRNGKey(42)
  params = init(rng, data)
  for step in range(10):
    params = train({}, rng, params, data)
    print('Loss:', loss(params, rng, data))
    print('Pred:', pred(params, rng, data).mean())


if __name__ == '__main__':
  main()
Traceback (most recent call last):
  File "/Users/danijar/temp/modules.py", line 96, in <module>
    main()
  File "/Users/danijar/temp/modules.py", line 84, in main
    init, (train, loss, pred) = hk.multi_transform(f)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 229, in multi_transform
    f = multi_transform_with_state(f)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 163, in multi_transform_with_state
    output_treedef = jax.eval_shape(get_output_treedef).python_value
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/_src/api.py", line 2967, in eval_shape
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 524, in abstract_eval_fun
    _, avals_out, _ = trace_to_jaxpr_dynamic(
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
    return func(*args, **kwargs)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 160, in get_output_treedef
    apply_fns, _ = fns.apply(*fns.init(rng), rng)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/transform.py", line 335, in init_fn
    f(*args, **kwargs)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 159, in <lambda>
    fns = hk.transform_with_state(lambda: f()[1])
  File "/Users/danijar/temp/modules.py", line 80, in f
    module = Distiller()
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 120, in __call__
    init(module, *args, **kwargs)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 416, in wrapped
    out = f(*args, **kwargs)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 266, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/danijar/temp/modules.py", line 31, in __init__
    self.target = Regressor()
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 120, in __call__
    init(module, *args, **kwargs)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 416, in wrapped
    out = f(*args, **kwargs)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 266, in run_interceptors
    return bound_method(*args, **kwargs)
  File "/Users/danijar/temp/modules.py", line 12, in __init__
    self.opt = Optimizer(self.loss, self.mlp)
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 124, in __call__
    module._auto_repr = utils.auto_repr(cls, *args, **kwargs)  # pylint: disable=protected-access
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/utils.py", line 87, in auto_repr
    single_line = cls.__name__ + "({})".format(", ".join(
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/utils.py", line 88, in <genexpr>
    name + repr(value) for name, value in names_and_values))
  File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 90, in <lambda>
    lambda module: module._auto_repr)  # pylint: disable=protected-access
AttributeError: 'Regressor' object has no attribute '_auto_repr'

danijar avatar May 22 '22 20:05 danijar