enn icon indicating copy to clipboard operation
enn copied to clipboard

Colab error with haiku/jax dependency

Open hstojic opened this issue 3 years ago • 0 comments

The following error appears when executing the Colab notebook:

AttributeError                            Traceback (most recent call last)
<ipython-input-7-4863fdbb2553> in <module>()
     13     num_ensemble=FLAGS.index_dim,
     14     prior_scale=FLAGS.prior_scale,
---> 15     seed=FLAGS.seed,
     16 )
     17 

4 frames
/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in __init__(self, output_sizes, dummy_input, num_ensemble, prior_scale, seed, w_init, b_init)
    137     """Ensemble of MLPs with matched prior functions."""
    138     mlp_priors = make_mlp_ensemble_prior_fns(
--> 139         output_sizes, dummy_input, num_ensemble, seed)
    140     enn = priors.EnnWithAdditivePrior(
    141         enn=MLPEnsembleEnn(

/usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in make_mlp_ensemble_prior_fns(output_sizes, dummy_input, num_ensemble, seed, w_init, b_init)
     90     return hk.Sequential(layers)(x)
     91 
---> 92   transformed = hk.without_apply_rng(hk.transform(net_fn))
     93 
     94   prior_fns = []

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform(f, apply_rng)
    301         "Replace hk.transform(..., apply_rng=True) with hk.transform(...).")
    302 
--> 303   return without_state(transform_with_state(f))
    304 
    305 

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform_with_state(f)
    359   """
    360   analytics.log_once("transform_with_state")
--> 361   check_not_jax_transformed(f)
    362 
    363   unexpected_tracer_hint = (

/usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)
    306 def check_not_jax_transformed(f):
    307   # TODO(tomhennigan): Consider `CompiledFunction = type(jax.jit(lambda: 0))`.
--> 308   if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)):  # pytype: disable=name-error
    309     raise ValueError("A common error with Haiku is to pass an already jit "
    310                      "(or pmap) decorated function into hk.transform (e.g. "

AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'

hstojic avatar Dec 03 '21 14:12 hstojic