dm-haiku
dm-haiku copied to clipboard
AttributeError: 'Regressor' object has no attribute '_auto_repr'
I'm trying to get the following code to work but I'm getting the error mentioned in the title (full stack trace below). Putting hk.experimental.module_auto_repr(False) at the top of the file avoids this error (though it then raises a new, unrelated error), so I'm assuming this is a bug in Haiku?
import haiku as hk
import jax
import jax.numpy as jnp
import optax
class Regressor(hk.Module):
    """An MLP regressor that owns its own Optimizer.

    The module exposes four entry points intended to be lifted out via
    ``hk.multi_transform``: ``__call__`` (prediction), ``build`` (parameter
    creation), ``train`` (one optimizer step) and ``loss`` (MSE objective).
    """

    def __init__(self):
        super().__init__()
        self.mlp = hk.nets.MLP([128, 128, 8])
        # The optimizer closes over self.loss and updates only self.mlp's
        # parameters; Optimizer is defined later in this file but is in
        # scope by the time __init__ runs.
        self.opt = Optimizer(self.loss, self.mlp)

    def __call__(self, data):
        # Forward pass on the 'x' entry of the data dict.
        return self.mlp(data['x'])

    def build(self, data):
        # Run the forward pass once so hk.transform creates the parameters.
        self.mlp(data['x'])

    def train(self, params, data):
        # Bug fix: the original assigned the updated params to a local and
        # implicitly returned None, so every caller (including
        # Distiller.train) received None instead of the updated params.
        return self.opt(params, data)

    def loss(self, data):
        # Mean squared error against the 'y' targets.
        return ((self.mlp(data['x']) - data['y']) ** 2).mean()
class Distiller(hk.Module):
    """Distills a target Regressor into a second, independently trained MLP."""

    def __init__(self):
        super().__init__()
        self.target = Regressor()
        self.mlp = hk.nets.MLP([128, 128, 8])
        self.opt = Optimizer(self.loss, self.mlp)

    def build(self, data):
        # Touch every submodule once so parameter creation happens under init.
        self.target.build(data)
        self.mlp(data['x'])

    def train(self, params, data):
        # First step the target regressor, then step the distillation MLP.
        params = self.target.train(params, data)
        return self.opt(params, data)

    def loss(self, data):
        # Regress the student MLP onto the target's (gradient-blocked) output.
        teacher_out = jax.lax.stop_gradient(self.target(data))
        student_out = self.mlp(data['x'])
        return ((student_out - teacher_out) ** 2).mean()
class Optimizer(hk.Module):
    """Wraps an optax optimizer around the parameters of one sub-module.

    Takes a loss function and the module whose parameters it should train;
    on call, it computes gradients of the transformed loss with respect to
    only that module's parameters and applies one adam update.
    """

    def __init__(self, loss, module, lr=1e-4):
        super().__init__()
        # Transform the (bound-method) loss so it can be applied with an
        # explicit params dict below.
        self.tloss = hk.transform(loss)
        self.module = module
        self.opt = optax.adam(lr)
        # TODO: Does this work? Will it be included in checkpoints?
        # NOTE(review): plain Python attribute state is NOT tracked by Haiku
        # transforms — hk.get_state/hk.set_state is the supported mechanism;
        # this attribute will be re-created on every trace. Confirm against
        # the Haiku stateful-modules docs.
        self.state = self.opt.init(self.module.params_dict())

    def __call__(self, params, *a, **k):
        # Split the full params dict into the subset owned by self.module
        # (trainable) and everything else (frozen).
        module_params = self.module.params_dict()
        # Debugging output left in by the reporter.
        print(dir(self.module))
        print('MODULE PARAMS:', module_params)
        trainable, frozen = hk.data_structures.partition(
            lambda n, m, p: n in module_params, params)
        print('TRAINABLE:', trainable)

        def inner(trainable, frozen, *a, **k):
            # Re-merge so the transformed loss sees the complete params dict;
            # None is the rng (loss is assumed deterministic).
            params = hk.data_structures.merge(frozen, trainable)
            return self.tloss.apply(params, None, *a, **k)

        # NOTE(review): hk.grad is not a public Haiku API — presumably
        # jax.grad was intended here; verify.
        grads = hk.grad(inner)(trainable, frozen, *a, **k)
        # NOTE(review): mutating self.state inside a transformed/jitted
        # function will not persist across calls — see note in __init__.
        updates, self.state = self.opt.update(grads, self.state)
        trainable = optax.apply_updates(trainable, updates)
        params = hk.data_structures.merge(frozen, trainable)
        # TODO: Does this work?
        # self.module.params_dict().update(params)
        return params
def main():
    """Build the Distiller via multi_transform and run a short train loop."""

    def f():
        module = Distiller()

        def init(data):
            module.build(data)

        return init, (module.train, module.loss, module.mlp)

    init, (train, loss, pred) = hk.multi_transform(f)
    train, loss, pred = (jax.jit(fn) for fn in (train, loss, pred))

    data = {'x': jnp.zeros((16, 64)), 'y': jnp.zeros((16, 8))}
    rng = jax.random.PRNGKey(42)
    params = init(rng, data)
    for _ in range(10):
        params = train({}, rng, params, data)
    print('Loss:', loss(params, rng, data))
    print('Pred:', pred(params, rng, data).mean())


if __name__ == '__main__':
    main()
Traceback (most recent call last):
File "/Users/danijar/temp/modules.py", line 96, in <module>
main()
File "/Users/danijar/temp/modules.py", line 84, in main
init, (train, loss, pred) = hk.multi_transform(f)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 229, in multi_transform
f = multi_transform_with_state(f)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 163, in multi_transform_with_state
output_treedef = jax.eval_shape(get_output_treedef).python_value
File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/_src/api.py", line 2967, in eval_shape
out = pe.abstract_eval_fun(wrapped_fun.call_wrapped,
File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 524, in abstract_eval_fun
_, avals_out, _ = trace_to_jaxpr_dynamic(
File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/_src/profiler.py", line 206, in wrapper
return func(*args, **kwargs)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1828, in trace_to_jaxpr_dynamic
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(
File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1865, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/danijar/homebrew/lib/python3.9/site-packages/jax/linear_util.py", line 168, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 160, in get_output_treedef
apply_fns, _ = fns.apply(*fns.init(rng), rng)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/transform.py", line 335, in init_fn
f(*args, **kwargs)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/multi_transform.py", line 159, in <lambda>
fns = hk.transform_with_state(lambda: f()[1])
File "/Users/danijar/temp/modules.py", line 80, in f
module = Distiller()
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 120, in __call__
init(module, *args, **kwargs)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 416, in wrapped
out = f(*args, **kwargs)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 266, in run_interceptors
return bound_method(*args, **kwargs)
File "/Users/danijar/temp/modules.py", line 31, in __init__
self.target = Regressor()
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 120, in __call__
init(module, *args, **kwargs)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 416, in wrapped
out = f(*args, **kwargs)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 266, in run_interceptors
return bound_method(*args, **kwargs)
File "/Users/danijar/temp/modules.py", line 12, in __init__
self.opt = Optimizer(self.loss, self.mlp)
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 124, in __call__
module._auto_repr = utils.auto_repr(cls, *args, **kwargs) # pylint: disable=protected-access
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/utils.py", line 87, in auto_repr
single_line = cls.__name__ + "({})".format(", ".join(
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/utils.py", line 88, in <genexpr>
name + repr(value) for name, value in names_and_values))
File "/Users/danijar/homebrew/lib/python3.9/site-packages/haiku/_src/module.py", line 90, in <lambda>
lambda module: module._auto_repr) # pylint: disable=protected-access
AttributeError: 'Regressor' object has no attribute '_auto_repr'