rapids_singlecell icon indicating copy to clipboard operation
rapids_singlecell copied to clipboard

(feat): `_first_pass_qc` single dispatch refactor

Open ilan-gold opened this issue 2 months ago • 1 comments

I added a separate file here containing the refactor. It drops the number of lines by 50, and that's only a first pass (a second pass would remove more).

To reproduce my benchmark, I ran (on our current cluster setup):

CUDA_VISIBLE_DEVICES="0,1,2,3" srun --gres=gpu:4 --partition=dc-gpu --account=training2406 --reservation=gpuhack24-2024-04-25 --time=480  --pty /bin/bash -i

and then ran the following script, saved in a file called rsc_example.py, via either python rsc_example.py refactored or python rsc_example.py current. The script prints the time taken and the AnnData object (so one can see that the QC metrics were calculated). For me, both implementations take about 800-900 milliseconds.

import time
import sys

import anndata

import dask
from dask_cuda import LocalCUDACluster
from dask.distributed import Client

import cudf
from cuml.dask.common.part_utils import _extract_partitions
import cupy as cp
from cupyx.scipy import sparse

import h5py
import rapids_singlecell as rsc
import rmm
from rmm.allocators.cupy import rmm_cupy_allocator



def set_mem():
    """Re-initialise RMM with managed memory and make CuPy allocate through RMM.

    Affects only the process it is called in.
    """
    rmm_options = {"managed_memory": True}
    rmm.reinitialize(**rmm_options)

    # Hand CuPy the RMM-backed allocator so its allocations go through RMM.
    allocator = rmm_cupy_allocator
    cp.cuda.set_allocator(allocator)

def read_with_filter(client,
                     sample_file, batch_size = 50000):
    """
    Lazily read the sparse matrix stored under ``/X`` of an h5ad file.

    The rows are split into batches of at most ``batch_size`` rows. Each batch
    is loaded by a delayed task into a CuPy CSR matrix, and the batches are
    concatenated into a single Dask array that is persisted before returning,
    so the partitions can be spread over several GPUs.

    NOTE(review): despite the function name, no cell or gene count filter is
    applied here; the matrix is returned as stored.

    Parameters
    ----------
    client
        Not used by this function; kept so existing callers keep working.
    sample_file
        Path to the h5ad (HDF5) file. The code reads ``/X/data``,
        ``/X/indices`` and ``/X/indptr`` and treats them as a CSR matrix
        (assumes the file stores ``X`` in CSR layout -- TODO confirm for other
        inputs), and reads the gene names from ``/var/_index``.
    batch_size
        Maximum number of rows per partition.

    Returns
    -------
    tuple
        ``(dask_sparse_arr, genes)``: the persisted Dask array of CSR
        partitions and a ``cudf.Series`` built from ``/var/_index``.
    """
    # Imported locally and explicitly: a bare ``import dask`` does not load
    # the ``dask.array`` submodule on its own.
    import dask.array as da

    # Paths inside the h5 file.
    data_path = '/X/data'
    indices_path = '/X/indices'
    indptr_path = '/X/indptr'
    # Other datasets may keep the gene ids under a different key, e.g.
    # '/var/ensembl_ids', '/var/ensembl_id' or '/var/feature_id'.
    genes_path = '/var/_index'

    @dask.delayed
    def _read_partition_to_sparse_matrix(sample_file,
                                         total_cols, batch_start, batch_end,
                                         ):
        """Load rows ``[batch_start, batch_end)`` as a CuPy CSR matrix."""
        with h5py.File(sample_file, 'r') as h5f:
            indptrs = h5f[indptr_path]
            start_ptr = indptrs[batch_start]
            end_ptr = indptrs[batch_end]

            # Read only the values and column indices belonging to this batch.
            sub_data = cp.array(h5f[data_path][start_ptr:end_ptr])
            sub_indices = cp.array(h5f[indices_path][start_ptr:end_ptr])

            # Recompute the row pointer so it starts at 0 for the partial
            # dataset.
            sub_indptrs = cp.array(indptrs[batch_start:(batch_end + 1)])
            sub_indptrs = sub_indptrs - sub_indptrs[0]

        # Reconstruct the partial sparse array. ``sparse`` is
        # ``cupyx.scipy.sparse``, the same namespace used for ``meta`` below
        # (the original used the legacy ``cp.sparse`` alias here).
        return sparse.csr_matrix(
            (sub_data, sub_indices, sub_indptrs),
            shape=(batch_end - batch_start, total_cols))

    with h5py.File(sample_file, 'r') as h5f:
        indptr = h5f[indptr_path]
        # Diagnostic output kept from the original script: lists the entries
        # available under /var (presumably to help pick ``genes_path``).
        var_group = h5f["/var/"]
        print(var_group.keys())
        genes = cudf.Series(h5f[genes_path], dtype=cp.dtype('object'))

        total_cols = genes.shape[0]
        # A CSR row pointer has one more entry than there are rows.
        max_cells = indptr.shape[0] - 1

    partitions = []
    for batch_start in range(0, max_cells, batch_size):
        # The last batch may be shorter than ``batch_size``.
        batch_end = min(batch_start + batch_size, max_cells)
        delayed_block = _read_partition_to_sparse_matrix(
            sample_file, total_cols, batch_start, batch_end)
        partitions.append(da.from_delayed(
            delayed_block,
            dtype=cp.float32,
            meta=sparse.csr_matrix(cp.array((1.,))),
            shape=(batch_end - batch_start, total_cols)))

    dask_sparse_arr = da.concatenate(partitions)
    dask_sparse_arr = dask_sparse_arr.persist()
    return dask_sparse_arr, genes

if __name__ == '__main__':
    # Usage: python rsc_example.py {refactored|current}
    # Maps the command-line choice to the implementation being benchmarked.
    funcs = {
        "refactored": rsc.pp.calculate_qc_metrics_refactored,
        "current": rsc.pp.calculate_qc_metrics
    }

    # Validate the argument before any GPU resources are acquired, so a typo
    # fails immediately with a usage message instead of an IndexError/KeyError
    # after the cluster has started and the data has been loaded.
    if len(sys.argv) < 2 or sys.argv[1] not in funcs:
        sys.exit("usage: python " + sys.argv[0] + " {" + "|".join(funcs) + "}")
    qc_func = funcs[sys.argv[1]]

    preprocessing_gpus="0, 1, 2, 3"
    cluster = LocalCUDACluster(CUDA_VISIBLE_DEVICES=preprocessing_gpus)
    client = Client(cluster)
    try:
        # Configure the allocator in this process and on every worker.
        set_mem()
        client.run(set_mem)

        # read_with_filter already persists the array it returns, so no
        # second persist() is needed here.
        dask_sparse_arr, genes = read_with_filter(
            client,
            "/p/scratch/training2406/team_scverse/scverse_data/1M_brain_cells_10X.sparse.h5ad",
            batch_size=50000
        )

        dask_sparse_arr.compute_chunk_sizes()
        adata = anndata.AnnData(dask_sparse_arr)
        rsc.pp.flag_gene_family(adata, gene_family_name="MT", gene_family_prefix="mt-")

        # perf_counter is monotonic, so the measured interval cannot be
        # skewed by wall-clock adjustments.
        start = time.perf_counter()
        qc_func(adata, qc_vars = "MT",client=client)
        print('TIME TAKEN:', time.perf_counter() - start)
        print('QCed ANNDATA:', adata)
    finally:
        # Always release the workers, even if the benchmark raised.
        client.retire_workers()
        client.shutdown()

ilan-gold avatar Apr 25 '24 14:04 ilan-gold