AttributeError: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].
Report
Hi, thank you for developing the tool.
When I try to import the library, I get the following error:
AttributeError: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].
The full error is:
{
"name": "AttributeError",
"message": "jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].",
"stack": "---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[7], line 1
----> 1 import pertpy as pt
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/pertpy/__init__.py:21
19 from . import plot as pl
20 from . import preprocessing as pp
---> 21 from . import tools as tl
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/pertpy/tools/__init__.py:42
32 from pertpy.tools._perturbation_space._discriminator_classifiers import (
33 LRClassifierSpace,
34 MLPClassifierSpace,
35 )
36 from pertpy.tools._perturbation_space._simple import (
37 CentroidSpace,
38 DBSCANSpace,
39 KMeansSpace,
40 PseudobulkSpace,
41 )
---> 42 from pertpy.tools._scgen import Scgen
44 CODA_EXTRAS = [\"toytree\", \"arviz\", \"ete3\"] # also pyqt5 technically
45 Sccoda = lazy_import(\"pertpy.tools._coda._sccoda\", \"Sccoda\", CODA_EXTRAS)
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/pertpy/tools/_scgen/__init__.py:1
----> 1 from pertpy.tools._scgen._scgen import Scgen
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/pertpy/tools/_scgen/_scgen.py:15
13 from lamin_utils import logger
14 from scipy import stats
---> 15 from scvi import REGISTRY_KEYS
16 from scvi.data import AnnDataManager
17 from scvi.data.fields import CategoricalObsField, LayerField
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/scvi/__init__.py:11
8 from ._settings import settings
10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import autotune, data, model, external, utils, criticism
13 from importlib.metadata import version
15 package_name = \"scvi-tools\"
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/scvi/autotune/__init__.py:1
----> 1 from ._manager import TuneAnalysis, TunerManager
2 from ._tuner import ModelTuner
3 from ._types import Tunable, TunableMixin
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/scvi/autotune/_manager.py:11
9 import lightning.pytorch as pl
10 import rich
---> 11 from chex import dataclass
13 try:
14 import ray
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/chex/__init__.py:17
1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the \"License\");
(...)
13 # limitations under the License.
14 # ==============================================================================
15 \"\"\"Chex: Testing made fun, in JAX!\"\"\"
---> 17 from chex._src.asserts import assert_axis_dimension
18 from chex._src.asserts import assert_axis_dimension_comparator
19 from chex._src.asserts import assert_axis_dimension_gt
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/chex/_src/asserts.py:26
23 import unittest
24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
27 from chex._src import pytypes
28 import jax
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/chex/_src/asserts_internal.py:34
31 from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
33 from absl import logging
---> 34 from chex._src import pytypes
35 import jax
36 from jax.experimental import checkify
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/chex/_src/pytypes.py:53
51 Scalar = Union[float, int]
52 Numeric = Union[Array, Scalar]
---> 53 Shape = jax.core.Shape
54 PRNGKey = jax.random.KeyArray
55 PyTreeDef = jax.tree_util.PyTreeDef
File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/jax/_src/deprecations.py:52, in deprecation_getattr.<locals>.getattr(name)
50 message, fn = deprecations[name]
51 if fn is None: # Is the deprecation accelerated?
---> 52 raise AttributeError(message)
53 warnings.warn(message, DeprecationWarning, stacklevel=2)
54 return fn
AttributeError: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any]."
}
Version information
No response
Thanks for the bug report! Could you please provide us with the version information?
Version: 0.6.0
But thanks, the problem is resolved. I figured out that my Python version is 3.12, but I had installed the tool with pip 3.9.
Reinstalling pertpy using `python -m pip install pertpy` solved the problem.
I encountered the same issue when installing pertpy today.
Python: 3.9.16
pertpy: 0.6.0
chex: 0.1.7
Downgrading jax and jaxlib to 0.4.23, lineax to 0.0.4, and ott-jax to 0.4.6 resolved the issue for me.