pertpy icon indicating copy to clipboard operation
pertpy copied to clipboard

AttributeError: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].

Open muntajihad opened this issue 1 year ago • 3 comments

Report

Hi, thank you for developing this tool.

When I try to import the library, I get the following error: AttributeError: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].

The full error is:

{
	"name": "AttributeError",
	"message": "jax.core.Shape is deprecated. Use Shape = Sequence[int | Any].",
	"stack": "---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[7], line 1
----> 1 import pertpy as pt

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/pertpy/__init__.py:21
     19 from . import plot as pl
     20 from . import preprocessing as pp
---> 21 from . import tools as tl

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/pertpy/tools/__init__.py:42
     32 from pertpy.tools._perturbation_space._discriminator_classifiers import (
     33     LRClassifierSpace,
     34     MLPClassifierSpace,
     35 )
     36 from pertpy.tools._perturbation_space._simple import (
     37     CentroidSpace,
     38     DBSCANSpace,
     39     KMeansSpace,
     40     PseudobulkSpace,
     41 )
---> 42 from pertpy.tools._scgen import Scgen
     44 CODA_EXTRAS = [\"toytree\", \"arviz\", \"ete3\"]  # also pyqt5 technically
     45 Sccoda = lazy_import(\"pertpy.tools._coda._sccoda\", \"Sccoda\", CODA_EXTRAS)

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/pertpy/tools/_scgen/__init__.py:1
----> 1 from pertpy.tools._scgen._scgen import Scgen

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/pertpy/tools/_scgen/_scgen.py:15
     13 from lamin_utils import logger
     14 from scipy import stats
---> 15 from scvi import REGISTRY_KEYS
     16 from scvi.data import AnnDataManager
     17 from scvi.data.fields import CategoricalObsField, LayerField

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/scvi/__init__.py:11
      8 from ._settings import settings
     10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import autotune, data, model, external, utils, criticism
     13 from importlib.metadata import version
     15 package_name = \"scvi-tools\"

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/scvi/autotune/__init__.py:1
----> 1 from ._manager import TuneAnalysis, TunerManager
      2 from ._tuner import ModelTuner
      3 from ._types import Tunable, TunableMixin

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/scvi/autotune/_manager.py:11
      9 import lightning.pytorch as pl
     10 import rich
---> 11 from chex import dataclass
     13 try:
     14     import ray

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/chex/__init__.py:17
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the \"License\");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 \"\"\"Chex: Testing made fun, in JAX!\"\"\"
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/chex/_src/asserts.py:26
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/chex/_src/asserts_internal.py:34
     31 from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
     33 from absl import logging
---> 34 from chex._src import pytypes
     35 import jax
     36 from jax.experimental import checkify

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/chex/_src/pytypes.py:53
     51 Scalar = Union[float, int]
     52 Numeric = Union[Array, Scalar]
---> 53 Shape = jax.core.Shape
     54 PRNGKey = jax.random.KeyArray
     55 PyTreeDef = jax.tree_util.PyTreeDef

File ~/miniconda3/envs/scanpy/lib/python3.11/site-packages/jax/_src/deprecations.py:52, in deprecation_getattr.<locals>.getattr(name)
     50 message, fn = deprecations[name]
     51 if fn is None:  # Is the deprecation accelerated?
---> 52   raise AttributeError(message)
     53 warnings.warn(message, DeprecationWarning, stacklevel=2)
     54 return fn

AttributeError: jax.core.Shape is deprecated. Use Shape = Sequence[int | Any]."
}

Version information

No response

muntajihad avatar Sep 26 '24 16:09 muntajihad

Thanks for the bug report! Could you please provide us with the version information?

Zethson avatar Sep 28 '24 13:09 Zethson

Version: 0.6.0. Thanks, but the problem is now resolved. I figured out that my Python version is 3.12, but I had installed the tool using the pip from Python 3.9. Reinstalling pertpy with `python -m pip install pertpy` solved the problem.

muntajihad avatar Sep 29 '24 13:09 muntajihad

I encountered the same issue when installing pertpy today.

Python: 3.9.16 pertpy: 0.6.0 chex: 0.1.7

Downgrading jax and jaxlib to 0.4.23, lineax to 0.0.4, and ott-jax to 0.4.6 resolved the issue for me.

merelkuijs avatar Sep 30 '24 01:09 merelkuijs