Problem occurred in enn_demo.ipynb

Open • fazaghifari opened this issue 2 years ago • 0 comments

Hi!

I was trying enn_demo.ipynb on Google Colab. Everything works fine until I run this block of code:

```python
# Train the experiment
experiment.train(FLAGS.num_batch)
```

Running it raises the error below. Is something wrong with my JAX version?

```
AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/enn/networks/ensembles.py in apply(params, states, inputs, index)
     82       sub_states = jax.tree_map(particle_selector, states)
     83       out, new_sub_states = model.apply(sub_params, sub_states, inputs)
---> 84       new_states = jax.tree_multimap(
     85           lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)
     86       return out, new_states

AttributeError: module 'jax' has no attribute 'tree_multimap'
```
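A possible workaround, sketched under the assumption that the installed JAX release simply no longer provides `jax.tree_multimap`: `jax.tree_util.tree_map` accepts one or more trees with the same structure, so it can stand in for the removed function. The shim below is hypothetical (it is not part of enn), and the toy `states` / `new_sub_states` dicts are made up purely to mirror the update on line 84 of `ensembles.py`. Because `ensembles.py` looks up `jax.tree_multimap` when `apply` runs, patching it in the notebook before calling `experiment.train` should be enough; pinning an older JAX release is the other obvious option.

```python
import jax
import jax.numpy as jnp

# Hypothetical compatibility shim (an assumption, not enn's own fix):
# alias the removed jax.tree_multimap to jax.tree_util.tree_map,
# which maps a function over several matching trees at once.
if not hasattr(jax, "tree_multimap"):
    jax.tree_multimap = jax.tree_util.tree_map

# Toy ensemble states: the leading axis indexes 3 ensemble members.
states = {"w": jnp.zeros((3, 2)), "b": jnp.zeros((3,))}
# Toy updated states for a single member (no ensemble axis).
new_sub_states = {"w": jnp.ones((2,)), "b": jnp.array(5.0)}
index = 1

# Same update as line 84: write each member's new state into slot `index`.
new_states = jax.tree_util.tree_map(
    lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)

# Row 1 of "w" is now ones; rows 0 and 2 are still zeros.
print(new_states["w"])
```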

Thanks, Adam

fazaghifari, Dec 27 '22 09:12