scvi-tools
Make JAX an optional dependency
I would find this very helpful. I keep getting this error when trying to import scvi-tools:
ImportError: cannot import name 'ShapedArray' from 'jax' (/users/rng/mambaforge/envs/scai-v4/lib/python3.10/site-packages/jax/__init__.py)
I tried pinning jax to 0.4.13, but it keeps getting replaced with v0.4.20 because that version is required by dependencies of pertpy, and scvi-tools is itself a requirement of pertpy. I can't figure out an environment where scvi-tools, scarches, and pertpy can coexist. I would appreciate it if anyone could share a working combination of package versions, especially for these packages:
python
pytorch
torchvision
torchaudio
pytorch-cuda
jax
jaxlib
chex
flax
scvi-tools
scarches
pertpy
Thanks!
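For anyone comparing environments, here is a minimal sketch that prints the installed versions of the packages listed above, using only the standard library. A few assumptions: the names are the pip distribution names (so `torch` rather than the conda name `pytorch`), `pytorch-cuda` is left out because it is a conda metapackage that pip metadata does not report, and `installed_versions` is a hypothetical helper, not part of any of these libraries.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Pip distribution names (assumed); conda names can differ, e.g. pytorch -> torch.
PACKAGES = [
    "torch", "torchvision", "torchaudio",
    "jax", "jaxlib", "chex", "flax",
    "scvi-tools", "scarches", "pertpy",
]


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    print(f"python: {sys.version.split()[0]}")
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
```

Pasting this output next to the traceback makes it easier to see which jax and flax versions ended up together after the solver ran.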
Hi, sorry that you're running into this issue. We're working on making JAX optional in #2318 as part of our next scvi-tools release (v1.1). However, I'm unable to reproduce the ImportError you're getting with JAX 0.4.20. Could you post the full traceback you're seeing? Thanks.
I am running into the same issue. My JAX version is 0.4.23.
ImportError                               Traceback (most recent call last)
Cell In[8], line 1
----> 1 import scvi

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/__init__.py:10
      7 from ._settings import settings
      9 # this import needs to come after prior imports to prevent circular import
---> 10 from . import data, model, external, utils
     12 # https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
     13 # https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
     14 try:

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/model/__init__.py:2
      1 from . import utils
----> 2 from ._amortizedlda import AmortizedLDA
      3 from ._autozi import AUTOZI
      4 from ._condscvi import CondSCVI

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/model/_amortizedlda.py:14
     12 from scvi.data import AnnDataManager
     13 from scvi.data.fields import LayerField
---> 14 from scvi.module import AmortizedLDAPyroModule
     15 from scvi.utils import setup_anndata_dsp
     17 from .base import BaseModelClass, PyroSviTrainMixin

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/module/__init__.py:4
      2 from ._autozivae import AutoZIVAE
      3 from ._classifier import Classifier
----> 4 from ._jaxvae import JaxVAE
      5 from ._mrdeconv import MRDeconv
      6 from ._multivae import MULTIVAE

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/scvi/module/_jaxvae.py:7
      5 import numpy as np
      6 import numpyro.distributions as dist
----> 7 from flax import linen as nn
      8 from flax.linen.initializers import variance_scaling
     10 from scvi import REGISTRY_KEYS

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/flax/__init__.py:20
     18 from . import core
     19 from . import jax_utils
---> 20 from . import linen
     21 from . import serialization
     22 from . import traverse_util

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/flax/linen/__init__.py:47
     18 # pylint: disable=g-multiple-import
     19 # re-export commonly used modules and functions
     20 from .activation import (
     21     PReLU as PReLU,
     22     celu as celu,
   (...)
     45     tanh as tanh
     46 )
---> 47 from .attention import (
     48     MultiHeadDotProductAttention as MultiHeadDotProductAttention,
     49     SelfAttention as SelfAttention,
     50     combine_masks as combine_masks,
     51     dot_product_attention as dot_product_attention,
     52     dot_product_attention_weights as dot_product_attention_weights,
     53     make_attention_mask as make_attention_mask,
     54     make_causal_mask as make_causal_mask
     55 )
     56 from .combinators import Sequential as Sequential
     57 from ..core import (
     58     DenyList as DenyList,
     59     FrozenDict as FrozenDict,
     60     broadcast as broadcast
     61 )

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/flax/linen/attention.py:22
     19 from flax.linen.dtypes import promote_dtype
     21 from flax.linen.initializers import zeros
---> 22 from flax.linen.linear import default_kernel_init
     23 from flax.linen.linear import DenseGeneral
     24 from flax.linen.linear import PrecisionLike

File ~/mambaforge/mambaforge/envs/scvi/lib/python3.9/site-packages/flax/linen/linear.py:30
     28 from jax import eval_shape
     29 from jax import lax
---> 30 from jax import ShapedArray
     31 import jax.numpy as jnp
     32 import numpy as np

ImportError: cannot import name 'ShapedArray' from 'jax'
It was fixed by installing flax; see #2216.
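The traceback above ends in flax/linen/linear.py running `from jax import ShapedArray`, so the failure comes from the installed flax asking for a name that the installed jax does not export at the top level. Below is a rough sketch for checking that directly before importing scvi. `has_top_level_name` is a hypothetical helper written for this thread, not an API of jax, flax, or scvi-tools.

```python
import importlib


def has_top_level_name(module_name, attr):
    """Return True/False if the module imports and does/doesn't expose attr.

    Returns None when the module itself cannot be imported.
    """
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attr)


if __name__ == "__main__":
    result = has_top_level_name("jax", "ShapedArray")
    if result is None:
        print("jax is not importable in this environment")
    elif result:
        print("jax exposes ShapedArray at top level")
    else:
        print("jax does not expose ShapedArray; the installed flax may be too old for it")
```

If the last message is printed, the jax and flax versions in the environment are likely out of step, which would be consistent with the flax fix mentioned above.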