trax
AttributeError: module 'jax.ops' has no attribute 'index_add'
Description
I am trying to do something basic in my code:
import numpy as np # regular ol' numpy
from trax import layers as tl # core building block
from trax import shapes # data signatures: dimensionality and type
from trax import fastmath # uses jax, offers numpy on steroids
It errors out on these imports alone, before the code does anything else. What am I doing wrong? Should I be pinning a different version of trax or jax?
Environment information
OS: CentOS
$ lsb_release
LSB Version: :core-4.1-amd64:core-4.1-ia32:core-4.1-noarch:cxx-4.1-amd64:cxx-4.1-ia32:cxx-4.1-noarch:desktop-4.1-amd64:desktop-4.1-ia32:desktop-4.1-noarch:languages-4.1-amd64:languages-4.1-noarch:printing-4.1-amd64:printing-4.1-noarch
$ pip freeze | grep trax
trax==1.3.9
$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-datasets==4.8.2
tensorflow-estimator==2.11.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.30.0
tensorflow-metadata==1.12.0
tensorflow-text==2.11.0
$ pip freeze | grep jax
jax==0.4.4
jaxlib==0.4.4
$ python -V
Python 3.9.16
### For bugs: reproduction and error logs
# Error logs:
...
1 # coding=utf-8
2 # Copyright 2021 The Trax Authors.
3 #
(...)
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
16 """Trax top level import."""
---> 18 from trax import data
19 from trax import fastmath
20 from trax import layers
File ./ds_work/miniconda3/envs/coursera-nlp/lib/python3.9/site-packages/trax/data/__init__.py:36, in <module>
16 """Functions and classes for obtaining and preprocesing data.
17
18 The ``trax.data`` module presents a flattened (no subpackages) public API.
(...)
...
217 'vjp': jax.vjp,
218 'vmap': jax.vmap,
219 }
AttributeError: module 'jax.ops' has no attribute 'index_add'
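A minimal check, independent of trax, to confirm the root cause: the traceback ends with trax looking up `jax.ops.index_add` during import, so this sketch only reports the installed jax version and whether that attribute still exists. `has_index_add` is a hypothetical helper name used for illustration. With the jax==0.4.4 listed above, the attribute should be reported as missing, matching the error.

```python
import jax
import jax.ops  # bind the submodule explicitly


def has_index_add() -> bool:
    # Hypothetical helper: trax 1.3.9 looks up jax.ops.index_add while
    # importing, so check whether the installed jax still provides it.
    return hasattr(jax.ops, "index_add")


print("jax version:", jax.__version__)
print("jax.ops.index_add available:", has_index_add())
```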
Downgrade jax to 0.2.21: `jax.ops.index_add` was deprecated in jax 0.2.22 (https://gitee.com/mirrors/JAX/blob/main/CHANGELOG.md#jax-0222-oct-12-2021) and is no longer available in jax 0.4.4, but trax 1.3.9 still calls it.
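For your own code that calls `jax.ops.index_add` directly, current jax offers the `.at` indexed-update property on arrays: `x.at[idx].add(y)` returns a new array with `y` added at `idx`. The sketch below is a hypothetical compatibility helper (the name `index_add` is illustrative, not part of trax or jax) that uses the old function when it exists and `.at` otherwise.

```python
import jax
import jax.ops
import jax.numpy as jnp


def index_add(x, idx, y):
    """Hypothetical shim: functional `x[idx] += y` on old and new jax."""
    if hasattr(jax.ops, "index_add"):
        # Older jax: the original function (deprecated since 0.2.22).
        return jax.ops.index_add(x, idx, y)
    # Newer jax: indexed update via the .at property; returns a new array.
    return x.at[idx].add(y)


x = jnp.zeros(5)
print(index_add(x, 1, 3.0))
# Repeated indices accumulate: index 0 receives 1.0 twice.
print(index_add(x, jnp.array([0, 0, 2]), 1.0))
```

Note that this does not make `import trax` work on jax 0.4.4, because trax 1.3.9 references `jax.ops.index_add` internally; for that, the downgrade suggested above is still needed.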