enn
Problem occurred in enn_demo.ipynb
Hi!
I was trying enn_demo.ipynb on Google Colab. Everything seems fine until I run this block of code:
```python
# Train the experiment
experiment.train(FLAGS.num_batch)
```
Then this error appears. Is something wrong with the JAX version?
```
AttributeError                            Traceback (most recent call last)
/usr/local/lib/python3.8/dist-packages/enn/networks/ensembles.py in apply(params, states, inputs, index)
     82     sub_states = jax.tree_map(particle_selector, states)
     83     out, new_sub_states = model.apply(sub_params, sub_states, inputs)
---> 84     new_states = jax.tree_multimap(
     85         lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)
     86     return out, new_states

AttributeError: module 'jax' has no attribute 'tree_multimap'
```
Thanks, Adam
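The traceback shows `ensembles.py` calling `jax.tree_multimap`, which the installed JAX does not provide. `jax.tree_util.tree_map(f, tree, *rest)` accepts several trees and covers the same use. A minimal workaround sketch, assuming you cannot upgrade enn to a release that no longer calls `tree_multimap`, is to alias the missing name before calling `experiment.train`. The `states`, `new_sub_states`, and `index` values below are hypothetical stand-ins for the ensemble state update on lines 84-85 of the traceback, used only to check that the aliased call behaves the same way.

```python
import jax
import jax.numpy as jnp

# Workaround (assumption): alias the missing jax.tree_multimap to
# jax.tree_util.tree_map, which also maps over several trees at once.
# Run this before experiment.train(...) so enn finds the attribute.
if not hasattr(jax, "tree_multimap"):
    jax.tree_multimap = jax.tree_util.tree_map

# Hypothetical stand-ins: one row of state per ensemble member,
# and a new sub-state for the member selected by `index`.
states = {"w": jnp.zeros((3, 2))}
new_sub_states = {"w": jnp.ones((2,))}
index = 1

# Same update as ensembles.py line 84-85: write the member's new state
# into row `index` of the stacked ensemble state.
new_states = jax.tree_multimap(
    lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)
print(new_states["w"])
```

Patching the `jax` module is a stopgap; pinning an older JAX release that still has `tree_multimap`, or using an enn version that calls `jax.tree_util.tree_map` directly, avoids modifying a third-party module at runtime.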