ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)
/tmp/ipykernel_34/2874194604.py:15: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display
  from IPython.core.display import display, HTML
ImportError                               Traceback (most recent call last)
Cell In[27], line 20
     18 # Import model definition from big_vision
     19 from big_vision.models.proj.paligemma import paligemma
---> 20 from big_vision.trainers.proj.paligemma import predict_fns
     22 # Import big vision utilities
     23 import big_vision.datasets.jsonl

File /kaggle/working/big_vision_repo/big_vision/trainers/proj/paligemma/predict_fns.py:20
     17 import functools
     19 from big_vision.pp import registry
---> 20 import big_vision.utils as u
     21 import einops
     22 import jax

File /kaggle/working/big_vision_repo/big_vision/utils.py:38
     36 import flax.jax_utils as flax_utils
     37 import jax
---> 38 from jax.experimental.array_serialization import serialization as array_serial
     39 import jax.numpy as jnp
     40 import ml_collections as mlc

File /opt/conda/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py:36
     34 from jax._src import sharding
     35 from jax._src import sharding_impls
---> 36 from jax._src.layout import Layout, DeviceLocalLayout as DLL
     37 from jax._src import typing
     38 from jax._src import util

ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)
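Both modules in the last frame belong to the same `jax` package, so a missing `DeviceLocalLayout` suggests the running process is seeing two different JAX versions at once. One plausible cause (not confirmed in this thread) is running `pip install -U jax` inside a kernel that had already imported `jax`: the old `jax._src.layout` stays cached in `sys.modules` while the newer `serialization.py` is read from disk. The minimal sketch below, meant to be run in a new cell of the affected kernel, compares the version recorded on disk with the version already loaded. The helper names are hypothetical, and `jax._src.layout` is a private module whose contents can change between releases.

```python
import importlib
import importlib.metadata as md
import sys
from typing import Dict, Optional


def installed_version(dist: str) -> Optional[str]:
    """Version recorded in the installed package metadata (what pip sees on disk)."""
    try:
        return md.version(dist)
    except md.PackageNotFoundError:
        return None


def loaded_version(module_name: str) -> Optional[str]:
    """__version__ of the module already imported in this process, if any."""
    mod = sys.modules.get(module_name)
    return getattr(mod, "__version__", None) if mod is not None else None


def has_attr(module_name: str, attr: str) -> Optional[bool]:
    """True/False if the module imports and does/doesn't define `attr`; None if the import fails."""
    try:
        return hasattr(importlib.import_module(module_name), attr)
    except Exception:  # a half-upgraded install can fail with more than ImportError
        return None


def report() -> Dict[str, object]:
    # Dict literals evaluate left to right, so "jax (loaded)" is read
    # before has_attr() below imports jax itself.
    return {
        "python": sys.version.split()[0],
        "jax (on disk)": installed_version("jax"),
        "jaxlib (on disk)": installed_version("jaxlib"),
        "jax (loaded)": loaded_version("jax"),
        # Private module: checked only because the traceback imports from it.
        "jax._src.layout has DeviceLocalLayout": has_attr("jax._src.layout", "DeviceLocalLayout"),
    }


if __name__ == "__main__":
    for key, value in report().items():
        print(f"{key}: {value}")
```

If `jax (loaded)` is set and differs from `jax (on disk)`, restarting the kernel so everything is re-imported from the upgraded install is the usual first step. If the two agree, the install itself may be mixed, and reinstalling `jax` and `jaxlib` together at matching versions is worth trying.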
I have the same error. What can we do?
Hi there. I don't really understand this issue; it doesn't look like you're importing chex here. Could you provide a step-by-step reproduction?
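For a step-by-step reproduction, one option is to retry the import that fails at the bottom of the traceback in a brand-new interpreter, which has nothing cached from earlier notebook cells. This sketch assumes the notebook and the child process share the same `sys.executable`; `import_in_fresh_interpreter` is a hypothetical helper name, not part of any library.

```python
import subprocess
import sys
from typing import Tuple

# The import that fails in the traceback reported in this thread.
FAILING_IMPORT = "from jax.experimental.array_serialization import serialization"


def import_in_fresh_interpreter(statement: str) -> Tuple[bool, str]:
    """Run `statement` in a new Python process using the same interpreter.

    A fresh process starts with an empty module cache, so the result reflects
    only what is installed on disk. Returns (succeeded, last stderr line).
    """
    proc = subprocess.run(
        [sys.executable, "-c", statement],
        capture_output=True,
        text=True,
    )
    stderr = proc.stderr.strip()
    # On failure, the last stderr line holds the exception type and message.
    last_line = stderr.splitlines()[-1] if stderr else ""
    return proc.returncode == 0, last_line


if __name__ == "__main__":
    ok, message = import_in_fresh_interpreter(FAILING_IMPORT)
    print("fresh interpreter:", "import OK" if ok else f"import FAILED: {message}")
```

If the import succeeds here but still fails in the notebook, the kernel is probably holding stale modules and a restart should clear it. If it fails here too, the installed packages themselves likely disagree, and this result plus the installed `jax`/`jaxlib` versions would make a useful reproduction.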