enn
enn copied to clipboard
Colab error with haiku/jax dependency
The following error appears when executing the Colab notebook:
AttributeError Traceback (most recent call last)
<ipython-input-7-4863fdbb2553> in <module>()
13 num_ensemble=FLAGS.index_dim,
14 prior_scale=FLAGS.prior_scale,
---> 15 seed=FLAGS.seed,
16 )
17
4 frames
/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in __init__(self, output_sizes, dummy_input, num_ensemble, prior_scale, seed, w_init, b_init)
137 """Ensemble of MLPs with matched prior functions."""
138 mlp_priors = make_mlp_ensemble_prior_fns(
--> 139 output_sizes, dummy_input, num_ensemble, seed)
140 enn = priors.EnnWithAdditivePrior(
141 enn=MLPEnsembleEnn(
/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in make_mlp_ensemble_prior_fns(output_sizes, dummy_input, num_ensemble, seed, w_init, b_init)
90 return hk.Sequential(layers)(x)
91
---> 92 transformed = hk.without_apply_rng(hk.transform(net_fn))
93
94 prior_fns = []
/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform(f, apply_rng)
301 "Replace hk.transform(..., apply_rng=True) with hk.transform(...).")
302
--> 303 return without_state(transform_with_state(f))
304
305
/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform_with_state(f)
359 """
360 analytics.log_once("transform_with_state")
--> 361 check_not_jax_transformed(f)
362
363 unexpected_tracer_hint = (
/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)
306 def check_not_jax_transformed(f):
307 # TODO(tomhennigan): Consider `CompiledFunction = type(jax.jit(lambda: 0))`.
--> 308 if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)): # pytype: disable=name-error
309 raise ValueError("A common error with Haiku is to pass an already jit "
310 "(or pmap) decorated function into hk.transform (e.g. "
AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'