
`assign_mixture_model()` is really slow, even using JAX with CUDA support

Open jdeu1023 opened this issue 2 months ago • 4 comments

Please make sure these conditions are met

  • [x] I have checked that this issue has not already been reported.
  • [x] I have confirmed this bug exists on the latest version of pertpy.
  • [x] (optional) I have confirmed this bug exists on the main branch.

Report

Hi all,

I am trying to do gRNA assignment with pertpy's assign_mixture_model(), but it seems extremely slow and so I would be very grateful for any advice! I am running this on an interactive GPU node on Wharton's HPC where JAX is installed with CUDA support. I am using a subset of the Gasperini data with the first 2500 gRNAs and the first 50k cells. It takes hours to run on this dataset, which seems weird to me. I get a constant stream of JAX/JIT compilation notices (when using JAX_LOG_COMPILES=1), which don't seem to slow down as it runs. I'd have expected a flurry of them to begin with, but it seems non-stop, so I'm wondering if this reveals unintended behavior. Thanks in advance for any help here!
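For context, here is a minimal standalone sketch (plain JAX, not pertpy code) of the caching behavior I would expect from jit: with unchanged shapes and dtypes, only the first call should emit a compilation log, so a non-stop stream might suggest that shapes or traced code keep changing on every iteration.

import jax
import jax.numpy as jnp

# Equivalent to JAX_LOG_COMPILES=1: log every XLA compilation.
jax.config.update("jax_log_compiles", True)

@jax.jit
def scale_and_sum(x):
    return (x * 2.0).sum()

a = jnp.ones((1000,))
scale_and_sum(a)                  # first call: traces and compiles (one log line)
scale_and_sum(a + 1.0)            # same shape/dtype: cache hit, no new compilation
scale_and_sum(jnp.ones((1001,)))  # new shape: retraces and recompiles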

Below are the exact steps I'm doing to recreate this.

First, I'm starting a fresh GPU session (via $ qlogin -q gpu.q), confirming that I do indeed have a GPU, and determining which version of CUDA I have.

$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Wed_Apr_17_19:19:55_PDT_2024
Cuda compilation tools, release 12.5, V12.5.40
Build cuda_12.5.r12.5/compiler.34177558_0

$ nvidia-smi
Wed Nov 12 13:40:59 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.42.02              Driver Version: 555.42.02      CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA L40S                    Off |   00000000:30:00.0 Off |                    0 |
| N/A   22C    P8             20W /  350W |       4MiB /  46068MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Next, I'm making my conda environment and setting a few shell variables:

conda create -n pertpy-jax-issue python=3.12 -y
conda activate pertpy-jax-issue

python -m pip install -U pip setuptools wheel
pip install -U pertpy scanpy anndata h5py filelock

pip install -U "jax[cuda12]"

unset LD_LIBRARY_PATH # prevents use of system CUDA, which lacks cuSPARSE

export JAX_PLATFORM_NAME=gpu
export JAX_LOG_COMPILES=1 # give detailed logs
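(For reference, the same two settings can also be applied from inside Python before any JAX work runs; a minimal sketch, assuming these config flag names are still accepted by the installed JAX version.)

import jax

# Programmatic equivalents of the shell variables above (assumed config names).
# Must run before any JAX backend is initialized.
jax.config.update("jax_platforms", "cuda")   # like JAX_PLATFORMS / JAX_PLATFORM_NAME=gpu
jax.config.update("jax_log_compiles", True)  # like JAX_LOG_COMPILES=1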

Finally, this is the code I run to actually do gRNA assignment, along with some checks to verify that everything seems as expected.

import time
import sys

import jax, jax.numpy as jnp
import pertpy as pt
import scanpy as sc

# This is a subset of the (high MOI) Gasperini data,
# with 2500 gRNAs and 50,000 cells
dataset_name = "grna_matrix.h5ad"

# verify python and pertpy versions
print("python version:", sys.version) # shows `3.12.12`
print("pertpy version:", pt.__version__) # shows `1.0.3`

# Checking that we see the GPU
print("JAX backend:", jax.default_backend()) # shows `JAX backend: gpu`
print("JAX devices:", jax.devices()) # shows `JAX devices: [CudaDevice(id=0)]`
# quick warmup to force GPU work, just in case this helps
x = jnp.ones((4096, 4096), dtype=jnp.float32)
t0 = time.time()
(x @ x.T).block_until_ready()
print(f"Warmup matmul completed in {time.time() - t0:.2f}s")

# Loading the data and running gRNA assignment
adata = sc.read_h5ad(dataset_name)

ga = pt.pp.GuideAssignment()
print("Running mixture-model assignment...")
t0 = time.time()
ga.assign_mixture_model(
    adata,
    max_assignments_per_cell=100, # I tried this with 40 too
    show_progress=True
)
dt = time.time() - t0
print(f"Finished in {dt:.1f}s")

Here is an example of the stream of logging output that I get. This is after letting it run for an hour, so I don't think it's just things warming up, and at that point it is only ~20% done on a dataset with just 2,500 gRNAs and 50,000 cells.

Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  18% 5:14:50WARNING:2025-11-12 15:16:02,874:jax._src.dispatch:222: Finished XLA compilation of jit(div) in 0.038060904 sec
WARNING:2025-11-12 15:16:02,877:jax._src.interpreters.pxla:1960: Compiling jit(_where) with global shapes and types (ShapedArray(bool[2,366]), ShapedArray(float64[2,366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:02,885:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(_where) in 0.007384777 sec
WARNING:2025-11-12 15:16:02,922:jax._src.dispatch:222: Finished XLA compilation of jit(_where) in 0.036580563 sec
WARNING:2025-11-12 15:16:02,926:jax._src.interpreters.pxla:1960: Compiling jit(neg) with global shapes and types (ShapedArray(float64[366]),). Argument mapping: (UnspecifiedValue,).
WARNING:2025-11-12 15:16:02,933:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(neg) in 0.006520033 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  18% 5:14:50WARNING:2025-11-12 15:16:02,964:jax._src.dispatch:222: Finished XLA compilation of jit(neg) in 0.030367374 sec
WARNING:2025-11-12 15:16:02,969:jax._src.interpreters.pxla:1960: Compiling jit(reduce_sum) with global shapes and types (ShapedArray(float64[366]),). Argument mapping: (UnspecifiedValue,).
WARNING:2025-11-12 15:16:02,977:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(reduce_sum) in 0.007550478 sec
WARNING:2025-11-12 15:16:03,021:jax._src.dispatch:222: Finished XLA compilation of jit(reduce_sum) in 0.043447256 sec
WARNING:2025-11-12 15:16:03,024:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,032:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.007074356 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  18% 5:14:50WARNING:2025-11-12 15:16:03,066:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.034173489 sec
WARNING:2025-11-12 15:16:03,070:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[366])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,077:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.006629467 sec
WARNING:2025-11-12 15:16:03,108:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.031239033 sec
WARNING:2025-11-12 15:16:03,112:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,119:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.006843328 sec
Working... ━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  18% 5:14:50WARNING:2025-11-12 15:16:03,152:jax._src.dispatch:222: Finished XLA compilation of jit(mul) in 0.032195091 sec
WARNING:2025-11-12 15:16:03,156:jax._src.interpreters.pxla:1960: Compiling jit(div) with global shapes and types (ShapedArray(float64[366]), ShapedArray(float64[])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,163:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(div) in 0.006853342 sec
WARNING:2025-11-12 15:16:03,201:jax._src.dispatch:222: Finished XLA compilation of jit(div) in 0.037773132 sec
WARNING:2025-11-12 15:16:03,207:jax._src.interpreters.pxla:1960: Compiling jit(mul) with global shapes and types (ShapedArray(float64[366,2]), ShapedArray(float64[366,2])). Argument mapping: (UnspecifiedValue, UnspecifiedValue).
WARNING:2025-11-12 15:16:03,216:jax._src.dispatch:222: Finished jaxpr to MLIR module conversion jit(mul) in 0.007925987 sec
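If it helps with narrowing this down, here is a rough sketch (hypothetical timing harness, reusing the exact call from above and assuming guides are the AnnData variables, i.e. columns) of how one could time progressively larger guide subsets to see how the runtime scales with the number of gRNAs:

import time

import pertpy as pt
import scanpy as sc

adata = sc.read_h5ad("grna_matrix.h5ad")
ga = pt.pp.GuideAssignment()

# Time the same assignment call on growing guide subsets (cells x guides AnnData)
# to see whether runtime grows roughly linearly with the number of gRNAs.
for n_guides in (50, 100, 200):
    sub = adata[:, :n_guides].copy()
    t0 = time.time()
    ga.assign_mixture_model(sub, max_assignments_per_cell=100, show_progress=True)
    print(f"{n_guides} gRNAs: {time.time() - t0:.1f}s")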

Thanks again for any insights!

-Louis

Versions

| Package | Version |
| ------- | ------- |
| pertpy  | 1.0.3   |

| Dependency          | Version       |
| ------------------- | ------------- |
| numpy               | 2.3.4         |
| multipledispatch    | 1.0.0 (0.6.0) |
| toolz               | 1.1.0         |
| h5py                | 3.15.1        |
| Pygments            | 2.19.2        |
| scipy               | 1.16.3        |
| llvmlite            | 0.45.1        |
| idna                | 3.11          |
| anndata             | 0.12.6        |
| jax-cuda12-plugin   | 0.8.0         |
| pyparsing           | 3.2.5         |
| ply                 | 3.11          |
| blitzgsea           | 1.3.54        |
| cycler              | 0.12.1        |
| kiwisolver          | 1.4.9         |
| joblib              | 1.5.2         |
| natsort             | 8.4.0         |
| jax-cuda12-pjrt     | 0.8.0         |
| importlib_resources | 6.5.2         |
| certifi             | 2025.11.12    |
| msgpack             | 1.1.2         |
| numpyro             | 0.19.0        |
| equinox             | 0.13.2        |
| statsmodels         | 0.14.5        |
| donfig              | 0.8.1.post1   |
| xarray              | 2025.10.1     |
| opt_einsum          | 3.4.0         |
| fsspec              | 2025.10.0     |
| fast-array-utils    | 1.3           |
| setuptools          | 80.9.0        |
| requests            | 2.32.5        |
| charset-normalizer  | 3.4.4         |
| pillow              | 12.0.0        |
| filelock            | 3.20.0        |
| python-dateutil     | 2.9.0.post0   |
| zarr                | 3.1.3         |
| six                 | 1.17.0        |
| psutil              | 7.1.3         |
| jaxlib              | 0.8.0         |
| legacy-api-wrap     | 1.5           |
| urllib3             | 2.5.0         |
| absl-py             | 2.3.1         |
| rich                | 14.2.0        |
| pandas              | 2.3.3         |
| flax                | 0.12.0        |
| pyarrow             | 22.0.0        |
| simplejson          | 3.20.2        |
| jax                 | 0.8.0         |
| ott-jax             | 0.6.0         |
| typing_extensions   | 4.15.0        |
| scikit-learn        | 1.7.2         |
| optax               | 0.2.6         |
| pyomo               | 6.9.5         |
| crc32c              | 2.8           |
| lineax              | 0.0.8         |
| adjustText          | 1.3.0         |
| PubChemPy           | 1.0.5         |
| sparsecca           | 0.3.1         |
| mudata              | 0.3.2         |
| pytz                | 2025.2        |
| PyYAML              | 6.0.3         |
| scanpy              | 1.11.5        |
| packaging           | 25.0          |
| wadler_lindig       | 0.1.7         |
| threadpoolctl       | 3.6.0         |
| matplotlib          | 3.10.7        |
| seaborn             | 0.13.2        |
| lamin_utils         | 0.15.0        |
| numba               | 0.62.1        |
| chex                | 0.1.91        |
| etils               | 1.13.0        |
| jaxopt              | 0.8.5         |
| numcodecs           | 0.16.3        |
| jaxtyping           | 0.3.3         |
| ml_dtypes           | 0.5.3         |
| scikit-misc         | 0.5.2         |
| tqdm                | 4.67.1        |
| session-info2       | 0.2.3         |
| patsy               | 1.0.2         |
| mpmath              | 1.3.0         |

| Component | Info                                                                              |
| --------- | --------------------------------------------------------------------------------- |
| Python    | 3.12.12 (packaged by Anaconda, Inc.; main, Oct 21 2025, 20:16:04) [GCC 11.2.0]     |
| OS        | Linux-4.18.0-553.5.1.el8_10.x86_64-x86_64-with-glibc2.28                            |
| Updated   | 2025-11-12 20:27                                                                    |

jdeu1023 avatar Nov 12 '25 20:11 jdeu1023

Thanks! We haven't really performance-tested it on GPUs, but something seems to be off. It might take a week or two for us to get to it, as the scverse conference is also coming up.

Could you please just confirm that other GPU-accelerated JAX algorithms (outside of pertpy) are as performant as expected on your HPC?

Zethson avatar Nov 12 '25 21:11 Zethson

Thank you for your rapid response! I will update shortly as I confirm that it's not just a cluster issue!

jdeu1023 avatar Nov 14 '25 21:11 jdeu1023

Hi @Zethson, I believe that the cluster I'm on is indeed using the GPU to speed JAX up! I installed jax[cuda12] into a fresh conda environment, and here's the crux of the test that I ran in that environment to confirm this:

import time

import jax
import jax.numpy as jnp
from jax import random

n = 10000  # matrix dimension
reps = 5   # number of timed repetitions
key = random.PRNGKey(0) # using this to get the same matrices for CPU and GPU
a = random.normal(key, (n, n), dtype=jnp.float32)
b = random.normal(key, (n, n), dtype=jnp.float32)

@jax.jit
def matmul(x, y):
    return x @ y

# Warmup to trigger JIT compilation (not timed)
c = matmul(a, b).block_until_ready()

# Timed repetitions
t0 = time.time()
for i in range(reps):
    c = matmul(a, b).block_until_ready()
t1 = time.time()
total = t1 - t0
print(f"{reps} repetitions took {total:.2f}s")

I ran this on the GPU queue with both JAX_PLATFORM_NAME=gpu and JAX_PLATFORM_NAME=cpu. With JAX_PLATFORM_NAME=gpu, the 5 repetitions of this matrix multiplication took 0.10 s; with JAX_PLATFORM_NAME=cpu, they took 24.26 s. That seems like confirmation to me that the GPU was used, at least for this test? I'm happy to run additional tests if that would provide more insight!
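In case a further check is useful, here is a small sketch of how to confirm which device a computed array actually lives on (assuming jax.Array.devices() is available in this JAX version):

import jax
import jax.numpy as jnp

# Quick device sanity check for a computed array.
x = jnp.ones((1024, 1024), dtype=jnp.float32)
y = (x @ x).block_until_ready()
print("Available devices:", jax.devices())  # e.g. [CudaDevice(id=0)]
print("Result lives on:", y.devices())      # e.g. {CudaDevice(id=0)}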

Best, Louis

jdeu1023 avatar Nov 18 '25 04:11 jdeu1023

Got it, thanks! We'll try to look into it soon. Please remain patient for a bit.

Zethson avatar Nov 18 '25 09:11 Zethson

I'll look into this in about 14 days.

Zethson avatar Dec 02 '25 10:12 Zethson

Great, thank you!

jdeu1023 avatar Dec 03 '25 04:12 jdeu1023