`assign_mixture_model()` is really slow, even using JAX with CUDA support
Please make sure these conditions are met
- [x] I have checked that this issue has not already been reported.
- [x] I have confirmed this bug exists on the latest version of pertpy.
- [x] (optional) I have confirmed this bug exists on the main branch.
Report
Hi all,
I am trying to do gRNA assignment with pertpy's `assign_mixture_model()`, but it is extremely slow, so I would be very grateful for any advice! I am running this on an interactive GPU node on Wharton's HPC, where JAX is installed with CUDA support. I am using a subset of the Gasperini data with the first 2,500 gRNAs and the first 50,000 cells, and it takes hours to run on this dataset, which seems far too slow. With `JAX_LOG_COMPILES=1`, I see a constant stream of JAX JIT compilation messages that does not taper off as the run progresses. I would have expected a burst of compilations at the start and then few or none, so I'm wondering whether this indicates unintended behavior. Thanks in advance for any help!
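A minimal sketch of one possible explanation, offered as an assumption rather than a diagnosis of pertpy's code: `jax.jit` caches compiled code per input shape and dtype, so each call with a shape it has not seen before is traced and compiled again. The function `normalize` below is a hypothetical stand-in, and the counter relies on the fact that Python side effects inside a jitted function run only during tracing.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces (and then compiles) the function


@jax.jit
def normalize(x):
    # Hypothetical stand-in for a per-gRNA computation.
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return x / jnp.sum(x)


# Two calls with the same shape: traced once, then the cached executable is reused.
normalize(jnp.ones(366))
normalize(jnp.ones(366))
traces_same_shape = trace_count

# Calls with varying lengths: every previously unseen length is traced again.
for n in (120, 366, 401, 57):
    normalize(jnp.ones(n))
traces_total = trace_count

print(traces_same_shape, traces_total)  # 1 4
```

If array shapes change from one iteration of a long loop to the next, the compilation log never goes quiet, which would match the behavior described above.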
Below are the exact steps I use to reproduce this.
First, I start a fresh GPU session (via `qlogin -q gpu.q`), confirm that I do indeed have a GPU, and check which CUDA version is installed:
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Wed_Apr_17_19:19:55_PDT_2024
Cuda compilation tools, release 12.5, V12.5.40
Build cuda_12.5.r12.5/compiler.34177558_0
$ nvidia-smi
Wed Nov 12 13:40:59 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02 Driver Version: 555.42.02 CUDA Version: 12.5 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA L40S Off | 00000000:30:00.0 Off | 0 |
| N/A 22C P8 20W / 350W | 4MiB / 46068MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+
Next, I create a conda environment and set a few shell variables:
conda create -n pertpy-jax-issue python=3.12 -y
conda activate pertpy-jax-issue
python -m pip install -U pip setuptools wheel
pip install -U pertpy scanpy anndata h5py filelock
pip install -U "jax[cuda12]"
unset LD_LIBRARY_PATH # prevents use of system CUDA, which lacks cuSPARSE
export JAX_PLATFORM_NAME=gpu
export JAX_LOG_COMPILES=1  # log every JIT compilation
Finally, here is the code I run to perform the gRNA assignment, along with some checks that everything is set up as expected:
import time
import sys
import jax
import jax.numpy as jnp
import pertpy as pt
import scanpy as sc
# This is a subset of the (high MOI) Gasperini data,
# with 2500 gRNAs and 50,000 cells
dataset_name = "grna_matrix.h5ad"
# verify python and pertpy versions
print("python version:", sys.version) # shows `3.12.12`
print("pertpy version:", pt.__version__) # shows `1.0.3`
# Checking that we see the GPU
print("JAX backend:", jax.default_backend()) # shows `JAX backend: gpu`
print("JAX devices:", jax.devices()) # shows `JAX devices: [CudaDevice(id=0)]`
# quick warmup to force GPU work, just in case this helps
x = jnp.ones((4096, 4096), dtype=jnp.float32)
t0 = time.time()
(x @ x.T).block_until_ready()
print(f"Warmup matmul completed in {time.time() - t0:.2f}s")
# Loading the data and running gRNA assignment
adata = sc.read_h5ad(dataset_name)
ga = pt.pp.GuideAssignment()
print("Running mixture-model assignment...")
t0 = time.time()
ga.assign_mixture_model(
adata,
max_assignments_per_cell=100, # I tried this with 40 too
show_progress=True
)
dt = time.time() - t0
print(f"Finished in {dt:.1f}s")
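If, as assumed here (not confirmed from pertpy's source), the model is fitted separately for each gRNA on the cells with nonzero counts for that gRNA, then the number of distinct array lengths, and hence of distinct compilations, roughly equals the number of distinct per-gRNA nonzero counts. The log excerpt below shows arrays of length 366, for example. A quick way to count distinct lengths on a cells × gRNAs matrix such as `adata.X` is `getnnz(axis=0)`; the helper `distinct_nonzero_sizes` and the toy matrix are hypothetical.

```python
import numpy as np
import scipy.sparse as sp


def distinct_nonzero_sizes(X):
    """Return the sorted distinct numbers of nonzero cells per column (gRNA).

    Assumes X is a cells x gRNAs count matrix, sparse or dense.
    """
    X = sp.csc_matrix(X, copy=True)  # copy so the caller's matrix is not modified
    X.eliminate_zeros()  # do not count explicitly stored zeros
    nnz_per_grna = X.getnnz(axis=0)  # number of nonzero cells in each column
    return np.unique(nnz_per_grna)


# Toy 5-cell x 4-gRNA matrix whose columns have 2, 1, 2 and 4 nonzero cells.
toy = np.array(
    [
        [3, 0, 1, 2],
        [0, 0, 0, 5],
        [1, 7, 0, 1],
        [0, 0, 4, 9],
        [0, 0, 0, 0],
    ]
)
sizes = distinct_nonzero_sizes(toy)
print(sizes)  # [1 2 4]
```

On the real data, `len(distinct_nonzero_sizes(adata.X))` would give the number of distinct per-gRNA lengths; under the assumption above, that is roughly how many times each shape-dependent operation could be compiled.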
Here is a sample of the logging output I get. This is after letting it run for one hour, so I don't think this is just warm-up! After an hour it is only 18% done, on a dataset with just 2,500 gRNAs and 50,000 cells.
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18% 5:14:50
WARNING:2025-11-12 15:16:02,874:jax._src.dispatch:222: Finished XLA compilation of jit(div) in 0.038060904 sec
WARNING:2025-11-12 15:16:02,877:jax._src.interpreters.pxla:1960: Compiling jit(_where) with global shapes and types (ShapedArray(bool[2,366]), ShapedArray(float64[2,366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:02,885:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(_where) in 0.007384777 sec
WARNING:2025-11-12 15:16:02,922:jax._src.dispatch:222: Finished XLA compilation of jit(_where) in 0.036580563 sec
WARNING:2025-11-12 15:16:02,926:jax._src.interpreters.pxla:1960: Compiling jit(neg) with global shapes and types (ShapedArray(float64[366]),). Argument mapping: (UnspecifiedValue,).
WARNING:2025-11-12 15:16:02,933:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(neg) in 0.006520033 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18% 5:14:50
WARNING:2025-11-12 15:16:02,964:jax._src.dispatch:222: Finished XLA compilation of jit(neg) in 0.030367374 sec
WARNING:2025-11-12 15:16:02,969:jax._src.interpreters.pxla:1960: Compiling jit(reduce_sum) with global shapes and types (ShapedArray(float64[366]),). Argument mapping: (UnspecifiedValue,).
WARNING:2025-11-12 15:16:02,977:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(reduce_sum) in 0.007550478 sec
WARNING:2025-11-12 15:16:03,021:jax._src.dispatch:222: Finished XLA compilation of jit(reduce_sum) in 0.043447256 sec
WARNING:2025-11-12 15:16:03,024:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,032:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.007074356 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18% 5:14:50
WARNING:2025-11-12 15:16:03,066:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.034173489 sec
WARNING:2025-11-12 15:16:03,070:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,077:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.006629467 sec
WARNING:2025-11-12 15:16:03,108:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.031239033 sec
WARNING:2025-11-12 15:16:03,112:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,119:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.006843328 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 18% 5:14:50
WARNING:2025-11-12 15:16:03,152:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.032195091 sec
WARNING:2025-11-12 15:16:03,156:jax._src.interpreters.pxla:1960: Compiling jit(div) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,163:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(div) in 0.006853342 sec
WARNING:2025-11-12 15:16:03,201:jax._src.dispatch:222: Finished XLA compilation of jit(div) in 0.037773132 sec
WARNING:2025-11-12 15:16:03,207:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366,2]), ShapedArray(float64[366,2])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,216:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.007925987 sec
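The excerpt above shows small operations (`jit(div)`, `jit(_where)`, `jit(neg)`, `jit(reduce_sum)`, `jit(mul)`) being compiled one at a time for arrays of length 366. This looks like operations being executed eagerly, with each one compiled separately for every new shape. A general JAX workaround for inputs of varying length, sketched here as a technique and not as pertpy's implementation, is to pad each input to a fixed bucket size (here, the next power of two) and mask the padding, so that only a few distinct shapes are ever compiled. `masked_mean` and `pad_to_bucket` are hypothetical names, and the lengths are made up.

```python
import jax
import jax.numpy as jnp
import numpy as np

trace_count = 0  # counts traces, i.e. compilations of masked_mean


@jax.jit
def masked_mean(x, mask):
    global trace_count
    trace_count += 1  # runs only when a new (shape, dtype) combination is traced
    # Padded entries have mask == False and contribute nothing to the sum.
    total = jnp.sum(jnp.where(mask, x, 0.0))
    return total / jnp.sum(mask)


def pad_to_bucket(values):
    """Pad a 1-D array to the next power of two and return (padded, mask)."""
    n = len(values)
    bucket = 1 << max(n - 1, 0).bit_length()  # smallest power of two >= n
    padded = np.zeros(bucket, dtype=np.float32)
    padded[:n] = values
    mask = np.zeros(bucket, dtype=bool)
    mask[:n] = True
    return jnp.asarray(padded), jnp.asarray(mask)


rng = np.random.default_rng(0)
results = []
for n in (300, 366, 401, 500):  # four different lengths, all in the 512 bucket
    values = rng.poisson(3.0, size=n).astype(np.float32)
    padded, mask = pad_to_bucket(values)
    results.append((float(masked_mean(padded, mask)), float(values.mean())))

print(trace_count)  # 1
```

The trade-off is some wasted work on padding (less than 2× with power-of-two buckets), in exchange for compiling once per bucket instead of once per distinct length.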
Thanks again for any insights!
-Louis
Versions
| Package | Version |
| ------- | ------- |
| pertpy | 1.0.3 |
| Dependency | Version |
| ------------------- | ------------- |
| numpy | 2.3.4 |
| multipledispatch | 1.0.0 (0.6.0) |
| toolz | 1.1.0 |
| h5py | 3.15.1 |
| Pygments | 2.19.2 |
| scipy | 1.16.3 |
| llvmlite | 0.45.1 |
| idna | 3.11 |
| anndata | 0.12.6 |
| jax-cuda12-plugin | 0.8.0 |
| pyparsing | 3.2.5 |
| ply | 3.11 |
| blitzgsea | 1.3.54 |
| cycler | 0.12.1 |
| kiwisolver | 1.4.9 |
| joblib | 1.5.2 |
| natsort | 8.4.0 |
| jax-cuda12-pjrt | 0.8.0 |
| importlib_resources | 6.5.2 |
| certifi | 2025.11.12 |
| msgpack | 1.1.2 |
| numpyro | 0.19.0 |
| equinox | 0.13.2 |
| statsmodels | 0.14.5 |
| donfig | 0.8.1.post1 |
| xarray | 2025.10.1 |
| opt_einsum | 3.4.0 |
| fsspec | 2025.10.0 |
| fast-array-utils | 1.3 |
| setuptools | 80.9.0 |
| requests | 2.32.5 |
| charset-normalizer | 3.4.4 |
| pillow | 12.0.0 |
| filelock | 3.20.0 |
| python-dateutil | 2.9.0.post0 |
| zarr | 3.1.3 |
| six | 1.17.0 |
| psutil | 7.1.3 |
| jaxlib | 0.8.0 |
| legacy-api-wrap | 1.5 |
| urllib3 | 2.5.0 |
| absl-py | 2.3.1 |
| rich | 14.2.0 |
| pandas | 2.3.3 |
| flax | 0.12.0 |
| pyarrow | 22.0.0 |
| simplejson | 3.20.2 |
| jax | 0.8.0 |
| ott-jax | 0.6.0 |
| typing_extensions | 4.15.0 |
| scikit-learn | 1.7.2 |
| optax | 0.2.6 |
| pyomo | 6.9.5 |
| crc32c | 2.8 |
| lineax | 0.0.8 |
| adjustText | 1.3.0 |
| PubChemPy | 1.0.5 |
| sparsecca | 0.3.1 |
| mudata | 0.3.2 |
| pytz | 2025.2 |
| PyYAML | 6.0.3 |
| scanpy | 1.11.5 |
| packaging | 25.0 |
| wadler_lindig | 0.1.7 |
| threadpoolctl | 3.6.0 |
| matplotlib | 3.10.7 |
| seaborn | 0.13.2 |
| lamin_utils | 0.15.0 |
| numba | 0.62.1 |
| chex | 0.1.91 |
| etils | 1.13.0 |
| jaxopt | 0.8.5 |
| numcodecs | 0.16.3 |
| jaxtyping | 0.3.3 |
| ml_dtypes | 0.5.3 |
| scikit-misc | 0.5.2 |
| tqdm | 4.67.1 |
| session-info2 | 0.2.3 |
| patsy | 1.0.2 |
| mpmath | 1.3.0 |
| Component | Info |
| --------- | --------------------------------------------------------------------------------- |
| Python | 3.12.12 \| packaged by Anaconda, Inc. \| (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] |
| OS | Linux-4.18.0-553.5.1.el8_10.x86_64-x86_64-with-glibc2.28 |
| Updated | 2025-11-12 20:27 |
Thanks! We haven't really performance-tested it on GPUs, but something seems to be off. It might take a week or two for us to get to it, as the scverse conference is also coming up.
Could you please confirm that other GPU-accelerated JAX algorithms (outside of pertpy) perform as expected on your HPC?
Thank you for your rapid response! I will post an update once I've confirmed that it's not just a cluster issue!
Hi @Zethson, I believe that the cluster I'm on is indeed using the GPU to speed JAX up! I installed jax[cuda12] into a fresh conda environment, and here's the crux of the test that I ran in that environment to confirm this:
import time
import jax
import jax.numpy as jnp
from jax import random
n = 10000  # matrix dimension
reps = 5  # number of timed repetitions
key = random.PRNGKey(0)  # fixed seed, so CPU and GPU runs use the same matrices
key_a, key_b = random.split(key)  # separate keys so that a and b differ
a = random.normal(key_a, (n, n), dtype=jnp.float32)
b = random.normal(key_b, (n, n), dtype=jnp.float32)
@jax.jit
def matmul(x, y):
    return x @ y
# Warmup to trigger JIT compilation (not timed)
c = matmul(a, b).block_until_ready()
# Timed repetitions
t0 = time.time()
for i in range(reps):
    c = matmul(a, b).block_until_ready()
t1 = time.time()
total = t1 - t0
print(f"{reps} matmuls of {n}x{n} took {total:.2f} s")
I ran this on the GPU queue with both `JAX_PLATFORM_NAME=gpu` and `JAX_PLATFORM_NAME=cpu`. With `JAX_PLATFORM_NAME=gpu`, 5 repetitions of this matrix multiplication took 0.10 s; with `JAX_PLATFORM_NAME=cpu`, they took 24.26 s, roughly 240 times slower. That seems to confirm that the GPU was used, at least for this test. I'm happy to run additional tests if that would provide more insight!
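To put these timings in perspective, they can be converted to throughput using the standard count of about 2n³ floating-point operations for an n × n matrix product. This only reuses the numbers reported above (n = 10,000, 5 repetitions, 0.10 s and 24.26 s); `tflops` is a hypothetical helper.

```python
def tflops(n, reps, seconds):
    """Approximate TFLOP/s for `reps` products of two n x n matrices."""
    flops_per_matmul = 2 * n**3  # n^2 outputs, each needing n multiplies and n adds
    return flops_per_matmul * reps / seconds / 1e12


gpu = tflops(10_000, 5, 0.10)
cpu = tflops(10_000, 5, 24.26)
print(f"GPU: {gpu:.0f} TFLOP/s, CPU: {cpu:.2f} TFLOP/s, speedup: {gpu / cpu:.0f}x")
# GPU: 100 TFLOP/s, CPU: 0.41 TFLOP/s, speedup: 243x
```

Note that this benchmark runs a single large compiled kernel, which is a different workload from the many small, separately compiled operations visible in the log earlier in this issue.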
Best, Louis
Got it, thanks! We'll try to look into it soon. Please remain patient for a bit.
I'll look into this in about 14 days.
Great, thank you!