Installation error when installing T5x

Open · jntdst opened this issue 1 year ago · 1 comment

ModuleNotFoundError                       Traceback (most recent call last)
in <cell line: 1>()
----> 1 import t5x
      2 from t5x import partitioning
      3 from t5x import train_state as train_state_lib
      4 from t5x import utils
      5 from t5x.examples.t5 import network

5 frames
/content/t5x/t5x/__init__.py
     15 """Import API modules."""
     16
---> 17 import t5x.adafactor
     18 import t5x.checkpoints
     19 import t5x.decoding

/content/t5x/t5x/adafactor.py
     63 import jax.numpy as jnp
     64 import numpy as np
---> 65 from t5x import utils
     66 from t5x.optimizers import OptimizerDef
     67 from t5x.optimizers import OptimizerState

/content/t5x/t5x/utils.py
     44 import jax.numpy as jnp
     45 import numpy as np
---> 46 import orbax.checkpoint
     47 import seqio
     48 from t5x import checkpoints

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/__init__.py
     17 import functools
     18
---> 19 from orbax.checkpoint import checkpoint_utils
     20 from orbax.checkpoint import lazy_utils
     21 from orbax.checkpoint import test_utils

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/checkpoint_utils.py
     23 from jax.sharding import Mesh
     24 import numpy as np
---> 25 from orbax.checkpoint import type_handlers
     26 from orbax.checkpoint import utils
     27

/usr/local/lib/python3.10/dist-packages/orbax/checkpoint/type_handlers.py
     22 from etils import epath
     23 import jax
---> 24 from jax.experimental.gda_serialization import serialization
     25 from jax.experimental.gda_serialization.serialization import get_tensorstore_spec
     26 import jax.numpy as jnp

ModuleNotFoundError: No module named 'jax.experimental.gda_serialization'

jntdst · Feb 17 '24 08:02
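Before upgrading anything, a quick check can confirm that the failure is a version mismatch between the installed JAX and orbax-checkpoint (the old orbax/checkpoint/type_handlers.py imports a JAX module that the installed JAX does not provide) rather than a broken T5X checkout. This is a minimal sketch and not part of T5X or Orbax: the helper names module_available and installed_version are hypothetical, and it assumes the packages are installed under the pip distribution names jax, orbax-checkpoint and t5x.

```python
import importlib.metadata
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the module `name` can be located.

    find_spec imports the parent packages (e.g. jax, jax.experimental) but not
    the final submodule; an ImportError while importing a parent means the
    module is unavailable.
    """
    try:
        return importlib.util.find_spec(name) is not None
    except ImportError:
        return False


def installed_version(dist: str) -> str:
    """Return the installed version of a pip distribution, or 'not installed'."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return "not installed"


if __name__ == "__main__":
    # Assumed distribution names; adjust if your environment differs.
    for dist in ("jax", "orbax-checkpoint", "t5x"):
        print(f"{dist}: {installed_version(dist)}")
    print(
        "jax.experimental.gda_serialization available:",
        module_available("jax.experimental.gda_serialization"),
    )
```

If the check reports the module as unavailable while the traceback still shows orbax-checkpoint importing it, the installed orbax-checkpoint is older than the installed JAX expects.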

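Make sure you update both your T5X and orbax-checkpoint packages. The traceback shows an older orbax-checkpoint whose type_handlers.py still imports jax.experimental.gda_serialization, which your installed JAX does not provide. T5X depends on Orbax at head, which no longer depends on gda_serialization, so updating both packages should resolve the error.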

cpgaffney1 · Feb 20 '24 19:02
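The update suggested in the reply can be scripted from a notebook cell. This is a minimal sketch, assuming T5X was cloned to /content/t5x (the path in the traceback) and installed from that checkout; T5X_DIR, upgrade_commands and run are hypothetical names, and the exact install flags (for example extras such as [tpu]) may differ in your setup. By default it only prints the commands.

```python
import subprocess
import sys
from typing import List

# Assumption: T5X was cloned to /content/t5x (the path shown in the traceback)
# and installed from that checkout. Adjust T5X_DIR if yours lives elsewhere.
T5X_DIR = "/content/t5x"


def upgrade_commands(t5x_dir: str = T5X_DIR) -> List[List[str]]:
    """Build (but do not run) the commands that update T5X and orbax-checkpoint."""
    return [
        # Pull the latest T5X sources into the existing checkout.
        ["git", "-C", t5x_dir, "pull"],
        # Upgrade orbax-checkpoint to the newest release on PyPI.
        [sys.executable, "-m", "pip", "install", "-U", "orbax-checkpoint"],
        # Reinstall T5X in editable mode from the updated checkout.
        [sys.executable, "-m", "pip", "install", "-e", t5x_dir],
    ]


def run(commands: List[List[str]], dry_run: bool = True) -> None:
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print("$", " ".join(cmd))
        if not dry_run:
            # check_call raises CalledProcessError if a command fails.
            subprocess.check_call(cmd)


if __name__ == "__main__":
    run(upgrade_commands())  # pass dry_run=False to actually apply the update
```

Modules that were already imported stay cached in the running Python process, so restart the Colab runtime after upgrading before running import t5x again.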