trax
ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention'
Description
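An ImportError is raised as soon as trax is imported: the first trax import in the notebook (`from trax import fastmath`) fails with the error shown below.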
Environment information
trax 1.4.1
OS: Ubuntu
$ pip freeze | grep trax
trax 1.4.1
$ pip freeze | grep tensor
mesh-tensorflow==0.1.21
tensor2tensor==1.15.7
tensorboard==2.11.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.11.0
tensorflow-addons==0.18.0
tensorflow-datasets==4.7.0
tensorflow-estimator==2.11.0
tensorflow-gan==2.1.0
tensorflow-hub==0.12.0
tensorflow-io-gcs-filesystem==0.28.0
tensorflow-metadata==1.12.0
tensorflow-probability==0.7.0
tensorflow-text==2.11.0
tensorstore==0.1.28
$ pip freeze | grep jax
jax==0.3.25
jaxlib==0.3.25
$ python -V
Python 3.8.10
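The `pip freeze | grep` output above can also be collected from inside the same interpreter or notebook kernel that raises the error. This helps confirm that the reported versions are the ones the kernel actually loads, and not those of a different environment seen by the shell. Below is a minimal sketch using only the standard-library `importlib.metadata` (Python 3.8+). `collect_versions` is a hypothetical helper name and is not part of trax.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional
import sys


def collect_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return {distribution name: installed version, or None if not installed}.

    Hypothetical helper; it only reads installed package metadata and does not
    import the packages themselves, so it cannot trigger the ImportError.
    """
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    # Distribution names as they appear in the pip freeze output of this report.
    wanted = ["trax", "tensorflow", "tensorflow-text", "tensor2tensor", "jax", "jaxlib"]
    for name, version in collect_versions(wanted).items():
        print(f"{name}=={version}" if version else f"{name}: not installed")
```

Running this in a notebook cell (after the `!pip install` line) reports the versions visible to that kernel.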
For bugs: reproduction and error logs
# Steps to reproduce:
!pip install -q -U trax
import numpy as np # regular ol' numpy
from trax import fastmath
from trax import layers as tl
from trax import shapes
from trax.fastmath import numpy as jnp # For use in defining new layer types.
from trax.shapes import ShapeDtype
from trax.shapes import signature
# Error logs:
---------------------------------------------------------------------------
ImportError Traceback (most recent call last)
Cell In[22], line 3
1 import numpy as np # regular ol' numpy
----> 3 from trax import fastmath
4 from trax import layers as tl
5 from trax import shapes
File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/__init__.py:18
1 # coding=utf-8
2 # Copyright 2021 The Trax Authors.
3 #
(...)
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
16 """Trax top level import."""
---> 18 from trax import data
19 from trax import fastmath
20 from trax import layers
File ~/NovaceneAI/trax_projects/.venv/lib/python3.8/site-packages/trax/data/__init__.py:70
67 from trax.data.inputs import UnBatch
68 from trax.data.inputs import UniformlySeek
---> 70 from trax.data.tf_inputs import add_eos_to_output_features
71 from trax.data.tf_inputs import BertGlueEvalStream
...
35 from trax.layers.attention import SplitIntoHeads
38 # Layers are always CamelCase, but functions in general are snake_case
39 # pylint: disable=invalid-name
ImportError: cannot import name 'MergeHeads' from 'trax.layers.attention'
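The traceback is truncated at the `...`, so it does not show which module asks for `MergeHeads`. It only shows that the failing module also imports `SplitIntoHeads` from `trax.layers.attention`. One way to narrow this down without triggering the failing import is to inspect the installed sources directly. The sketch below is a hypothetical diagnostic and not part of trax. It locates the package with `importlib.util.find_spec`, which for a top-level name does not execute `trax/__init__.py`. It then checks whether `layers/attention.py` defines `MergeHeads` and lists the files that import it. It assumes single-line `from ... import Name` statements, like the `SplitIntoHeads` import in the traceback. Parenthesized multi-line imports are not detected.

```python
import importlib.util
import pathlib
import re
from typing import List, Optional, Tuple


def find_package_dir(package: str) -> Optional[pathlib.Path]:
    """Locate an installed top-level package directory WITHOUT importing it."""
    # For a name without dots, find_spec only searches; it does not run __init__.py.
    spec = importlib.util.find_spec(package)
    if spec is None or not spec.submodule_search_locations:
        return None  # not installed, or a plain module rather than a package
    return pathlib.Path(list(spec.submodule_search_locations)[0])


def scan_for_name(pkg_dir: pathlib.Path, module_rel: str, name: str) -> Tuple[bool, List[pathlib.Path]]:
    """Return (is `name` defined in `module_rel`?, files that import `name`).

    A definition is a line starting with `def name`, `class name` or `name =`.
    A re-export via `from x import name` inside `module_rel` is NOT counted as a
    definition; that file shows up in the importer list instead.
    """
    escaped = re.escape(name)
    def_re = re.compile(rf"^(def|class)\s+{escaped}\b|^{escaped}\s*=", re.M)
    import_re = re.compile(rf"^\s*from\s+\S+\s+import\s+.*\b{escaped}\b", re.M)

    module_path = pkg_dir / module_rel
    defined = module_path.exists() and bool(
        def_re.search(module_path.read_text(encoding="utf-8", errors="replace"))
    )

    importers: List[pathlib.Path] = []
    for path in sorted(pkg_dir.rglob("*.py")):
        if import_re.search(path.read_text(encoding="utf-8", errors="replace")):
            importers.append(path.relative_to(pkg_dir))
    return defined, importers


if __name__ == "__main__":
    pkg_dir = find_package_dir("trax")
    if pkg_dir is None:
        print("trax is not installed in this environment")
    else:
        defined, importers = scan_for_name(pkg_dir, "layers/attention.py", "MergeHeads")
        print("installed at:", pkg_dir)
        print("MergeHeads defined in layers/attention.py:", defined)
        for rel in importers:
            print("imported by:", rel)
```

If the scan reports `MergeHeads` as not defined while some installed file still imports it, then the installed trax sources are internally inconsistent. The listed importer is the module that the truncated traceback does not name.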