scale function (_get_mean_var) updated for dense arrays, speedup up to ~4.65x

Open ashish615 opened this issue 1 year ago • 5 comments

Hi, we are submitting this PR to speed up the _get_mean_var function.

| | Time (sec) |
|---|---|
| Original | 18.49 |
| Updated | 3.97 |
| Speedup | 4.65743073 |

Experiment setup: AWS r7i.24xlarge

import time
import numpy as np

import pandas as pd

import scanpy as sc
from sklearn.cluster import KMeans

import os
import wget

import warnings



warnings.filterwarnings('ignore', 'Expected ')
warnings.simplefilter('ignore')
input_file = "./1M_brain_cells_10X.sparse.h5ad"

if not os.path.exists(input_file):
    print('Downloading import file...')
    wget.download('https://rapids-single-cell-examples.s3.us-east-2.amazonaws.com/1M_brain_cells_10X.sparse.h5ad',input_file)


# marker genes
MITO_GENE_PREFIX = "mt-" # Prefix for mitochondrial genes to regress out
markers = ["Stmn2", "Hes1", "Olig1"] # Marker genes for visualization

# filtering cells
min_genes_per_cell = 200 # Filter out cells with fewer genes than this expressed
max_genes_per_cell = 6000 # Filter out cells with more genes than this expressed

# filtering genes
min_cells_per_gene = 1 # Filter out genes expressed in fewer cells than this
n_top_genes = 4000 # Number of highly variable genes to retain

# PCA
n_components = 50 # Number of principal components to compute

# t-SNE
tsne_n_pcs = 20 # Number of principal components to use for t-SNE

# k-means
k = 35 # Number of clusters for k-means

# Gene ranking

ranking_n_top_genes = 50 # Number of differential genes to compute for each cluster

# Number of parallel jobs
sc._settings.ScanpyConfig.n_jobs = os.cpu_count()

start=time.time()
tr=time.time()
adata = sc.read(input_file)
adata.var_names_make_unique()
adata.shape
print("Total read time : %s" % (time.time()-tr))



tr=time.time()
# To reduce the number of cells:
USE_FIRST_N_CELLS = 1300000
adata = adata[0:USE_FIRST_N_CELLS]
adata.shape

sc.pp.filter_cells(adata, min_genes=min_genes_per_cell)
sc.pp.filter_cells(adata, max_genes=max_genes_per_cell)
sc.pp.filter_genes(adata, min_cells=min_cells_per_gene)
sc.pp.normalize_total(adata, target_sum=1e4)
print("Total filter and normalize time : %s" % (time.time()-tr))


tr=time.time()
sc.pp.log1p(adata)
print("Total log time : %s" % (time.time()-tr))


# Select highly variable genes
sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor = "cell_ranger")

# Retain marker gene expression
for marker in markers:
        adata.obs[marker + "_raw"] = adata.X[:, adata.var.index == marker].toarray().ravel()

# Filter matrix to only variable genes
adata = adata[:, adata.var.highly_variable]

ts=time.time()
#Regress out confounding factors (number of counts, mitochondrial gene expression)
mito_genes = adata.var_names.str.startswith(MITO_GENE_PREFIX)
n_counts = np.array(adata.X.sum(axis=1))
adata.obs['percent_mito'] = np.array(np.sum(adata[:, mito_genes].X, axis=1)) / n_counts
adata.obs['n_counts'] = n_counts


sc.pp.regress_out(adata, ['n_counts', 'percent_mito'])
print("Total regress out time : %s" % (time.time()-ts))

#scale

ts=time.time()
sc.pp.scale(adata)
print("Total scale time : %s" % (time.time()-ts))

Add a timer around the _get_mean_var call:

https://github.com/scverse/scanpy/blob/706d4ef65e5d65e04b788831e7fd65dbe6b2a61f/scanpy/preprocessing/_scale.py#L167
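One way to add such a timer is a small context manager wrapped around the call. This is only a sketch: the timer helper and its sink argument are hypothetical names introduced for illustration, and the summation at the end is a stand-in for the real _get_mean_var call.

```python
import time
from contextlib import contextmanager


@contextmanager
def timer(label, sink=None):
    """Time the enclosed block; print the elapsed seconds and optionally store them."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        if sink is not None:
            sink[label] = elapsed  # keep the number for later comparison
        print(f"{label}: {elapsed:.4f} s")


# Stand-in workload; inside scale() this would wrap the _get_mean_var call instead.
timings = {}
with timer("stand-in for _get_mean_var", timings):
    total = sum(i * i for i in range(100_000))
```

Recording the elapsed time in a dict makes it easy to compare the original and updated runs side by side.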

We could also create a _get_mean_var_std function that returns std as well, so that it does not have to be computed in the scale function (L168-L169).
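A combined helper of that kind might look roughly like the NumPy sketch below. It is not the PR's implementation: the function name is hypothetical, and two behaviours are assumptions made for illustration, namely the unbiased (n - 1) variance correction and replacing zero standard deviations with 1 so that constant columns are not divided by zero.

```python
import numpy as np


def get_mean_var_std(X, axis=0):
    """Return mean, unbiased variance and std of a dense array along `axis`.

    Hypothetical helper; assumes at least two entries along `axis`.
    """
    X = np.asarray(X, dtype=np.float64)
    n = X.shape[axis]
    mean = X.mean(axis=axis)
    mean_sq = np.square(X).mean(axis=axis)
    # E[x^2] - E[x]^2, rescaled from the population to the sample variance.
    var = (mean_sq - mean**2) * (n / (n - 1))
    var = np.clip(var, 0.0, None)  # guard against tiny negative rounding errors
    std = np.sqrt(var)
    std[std == 0] = 1.0  # assumption: constant columns are left unscaled
    return mean, var, std


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
X[:, 2] = 3.0  # a constant column, to exercise the zero-std case
mean, var, std = get_mean_var_std(X)
```

Returning all three values from one pass avoids a second traversal of var in the caller.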

ashish615 avatar Jun 05 '24 09:06 ashish615

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 76.31%. Comparing base (896e249) to head (7a1a62e). Report is 127 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3099      +/-   ##
==========================================
- Coverage   76.31%   76.31%   -0.01%     
==========================================
  Files         109      109              
  Lines       12513    12516       +3     
==========================================
+ Hits         9549     9551       +2     
- Misses       2964     2965       +1     
| Files with missing lines | Coverage Δ |
|---|---|
| src/scanpy/preprocessing/_utils.py | 95.12% <100.00%> (-2.25%) :arrow_down: |

codecov[bot] avatar Jun 05 '24 09:06 codecov[bot]

Benchmark changes

| Change | Before [ad657edf] | After [e7a46626] | Ratio | Benchmark (Parameter) |
|---|---|---|---|---|
| + | 259M | 310M | 1.2 | preprocessing_log.FastSuite.peakmem_mean_var('pbmc68k_reduced') |
| + | 1.16±0.04ms | 1.97±0.5ms | 1.69 | preprocessing_log.FastSuite.time_mean_var('pbmc68k_reduced') |
| + | 255M | 315M | 1.23 | preprocessing_log.peakmem_highly_variable_genes('pbmc68k_reduced') |
| - | 373M | 322M | 0.86 | preprocessing_log.peakmem_pca('pbmc68k_reduced') |
| - | 1.03G | 779M | 0.76 | preprocessing_log.peakmem_scale('pbmc3k') |
| - | 729±5ms | 517±5ms | 0.71 | preprocessing_log.time_scale('pbmc3k') |

Comparison: https://github.com/scverse/scanpy/compare/ad657edfb52e9957b9a93b3a16fc8a87852f3f09..e7a466265b08f6973a5cf3fecfc27879104c02f4 Last changed:

More details: https://github.com/scverse/scanpy/pull/3099/checks?check_run_id=26384736173

scverse-benchmark[bot] avatar Jun 17 '24 12:06 scverse-benchmark[bot]

I have some small improvements that I would like to add next week to improve precision for larger matrices.

Intron7 avatar Jun 20 '24 14:06 Intron7

@ashish615 after doing some benchmarking myself, I found that your solution for axis=1 underperforms compared to axis=0 for larger arrays. I think that is because of the memory access pattern you chose. I rewrote the function with that in mind. I'll make another PR to you, because for some reason you disallow us from making changes to your PR.
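To illustrate the memory access point, the sketch below computes both reductions while always walking a C-contiguous array one row at a time, so every read is over adjacent memory. It is not the rewritten function from this thread; the function names are hypothetical, plain NumPy loops stand in for a compiled kernel, and the (n - 1) variance correction is an assumption.

```python
import numpy as np


def mean_var_axis0_rowwise(X):
    """Per-column mean/var: accumulate each contiguous row into per-column sums."""
    n_rows, n_cols = X.shape
    sums = np.zeros(n_cols, dtype=np.float64)
    sq_sums = np.zeros(n_cols, dtype=np.float64)
    for i in range(n_rows):
        row = X[i].astype(np.float64)  # one contiguous read
        sums += row
        sq_sums += row * row
    mean = sums / n_rows
    var = (sq_sums / n_rows - mean**2) * (n_rows / (n_rows - 1))
    return mean, var


def mean_var_axis1_rowwise(X):
    """Per-row mean/var: reduce each contiguous row to two scalars."""
    n_rows, n_cols = X.shape
    mean = np.empty(n_rows, dtype=np.float64)
    var = np.empty(n_rows, dtype=np.float64)
    for i in range(n_rows):
        row = X[i].astype(np.float64)  # same contiguous read as above
        mean[i] = row.sum() / n_cols
        var[i] = (np.dot(row, row) / n_cols - mean[i] ** 2) * (n_cols / (n_cols - 1))
    return mean, var


rng = np.random.default_rng(1)
X = np.ascontiguousarray(rng.normal(size=(50, 8)).astype(np.float32))
mean0, var0 = mean_var_axis0_rowwise(X)
mean1, var1 = mean_var_axis1_rowwise(X)
```

In both cases the outer loop runs over rows; only the shape of the accumulator changes, which avoids strided reads down a column.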

Intron7 avatar Jun 26 '24 10:06 Intron7

The function should also work with 1 thread. numba.get_num_threads() is fine; it works well with the sparse arrays. But I have no experience with it inside of dask.
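The single-thread requirement can be checked on the work-splitting logic alone. The sketch below divides the rows into contiguous chunks, one per thread; row_chunks is a hypothetical helper, and n_threads is passed in explicitly here, whereas a numba kernel would presumably take it from numba.get_num_threads().

```python
def row_chunks(n_rows, n_threads):
    """Split range(n_rows) into contiguous (start, stop) chunks, one per thread.

    Hypothetical helper. The thread count is clamped so that a single thread,
    or more threads than rows, still covers every row exactly once.
    """
    n_threads = max(1, min(n_threads, n_rows))
    base, extra = divmod(n_rows, n_threads)
    chunks = []
    start = 0
    for t in range(n_threads):
        # The first `extra` chunks take one additional row each.
        stop = start + base + (1 if t < extra else 0)
        chunks.append((start, stop))
        start = stop
    return chunks


single = row_chunks(10, 1)   # one chunk spanning all rows
several = row_chunks(10, 3)  # uneven split across three threads
```

With one thread the result is a single chunk covering all rows, so the same kernel body runs unchanged.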

Intron7 avatar Jun 27 '24 13:06 Intron7