
Speed up ssGSEA and reduce memory by moving .fastRndWalk to C++

Open • rpolicastro opened this issue 1 year ago • 1 comment

Hello,

For ssGSEA on scRNA-seq data, it appears the code calls the .fastRndWalk function once per cell per gene set, i.e. n_cells * n_gene_sets times. I was curious whether moving this function to C++ could speed this up (and potentially make it more memory efficient), so I roughly reimplemented the function using Rcpp with a few minor changes.

Hijacking your vignette code to make some example data.

library("Rcpp")
library("GSVA")
library("BiocParallel")

set.seed(100)

p <- 20000 ## number of genes
n <- 100    ## number of samples (cells)
## simulate expression values from a standard Gaussian distribution
X <- matrix(rnorm(p*n), nrow=p,
            dimnames=list(paste0("g", 1:p), paste0("s", 1:n)))

X <- as(X, "CsparseMatrix")

## sample gene set sizes
gs <- as.list(sample(10:100, size=100, replace=TRUE))
## sample gene sets
gs <- lapply(gs, function(n, p)
                   paste0("g", sample(1:p, size=n, replace=FALSE)), p)
names(gs) <- paste0("gs", 1:length(gs))

Preparing the data to run the old and new functions.

X <- GSVA:::.filterFeatures(X, "ssgsea")

geneSets <- GSVA:::.mapGeneSetsToFeatures(gs, rownames(X))

n <- ncol(X)

R <- t(sparseMatrixStats::colRanks(X, ties.method = "average"))
mode(R) <- "integer"

Ra <- abs(R)^0.25

The R implementation of .fastRndWalk.

.fastRndWalk <- function(gSetIdx, geneRanking, j, Ra) {
    n <- length(geneRanking)
    k <- length(gSetIdx)
    idxs <- sort.int(match(gSetIdx, geneRanking))
    
    stepCDFinGeneSet2 <- 
        sum(Ra[geneRanking[idxs], j] * (n - idxs + 1)) /
        sum((Ra[geneRanking[idxs], j]))    
    
    
    stepCDFoutGeneSet2 <- (n * (n + 1) / 2 - sum(n - idxs + 1)) / (n - k)
    
    walkStat <- stepCDFinGeneSet2 - stepCDFoutGeneSet2

    walkStat
}

R_fastRndWalk <- function(){
  es <- bplapply(as.list(1:n), function(j) {
    geneRanking <- order(R[, j], decreasing=TRUE)
    es_sample <- lapply(geneSets, .fastRndWalk, geneRanking, j, Ra)
    
    unlist(es_sample)
  }, BPPARAM=SerialParam(progressbar=TRUE))
  es <- do.call("cbind", es)
  return(es)
}
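
For reference, reading .fastRndWalk as written above, the statistic it returns for one sample and one gene set can be transcribed as the formula below. The notation is introduced here only for illustration and is not GSVA's: n is the number of genes, k the gene set size, i_1 < ... < i_k are the 1-based positions of the set's genes in the decreasing ranking, and w_m is the Ra weight of the gene at position i_m.

```latex
\[
\mathrm{walkStat}
  = \frac{\sum_{m=1}^{k} w_m \,(n - i_m + 1)}{\sum_{m=1}^{k} w_m}
  - \frac{\frac{n(n+1)}{2} - \sum_{m=1}^{k} (n - i_m + 1)}{n - k}
\]
```

The first term corresponds to stepCDFinGeneSet2 and the second to stepCDFoutGeneSet2.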

Here's the Rcpp implementation of fasterRndWalk.

sourceCpp(code="
  #include <Rcpp.h>
  using namespace Rcpp;
  
  // [[Rcpp::export]]
  double fasterRndWalk(IntegerVector gSetIdx, IntegerVector geneRanking, int j, NumericMatrix Ra) {
    int n = geneRanking.size();
    int k = gSetIdx.size();
    IntegerVector idxs = match(gSetIdx, geneRanking) - 1;
    
    double sum1 = 0;
    double sum2 = 0;
    for (int i = 0; i < k; ++i) {
      int idx = idxs[i];
      double value = Ra(geneRanking[idx] - 1, j - 1);
      sum1 += value * (n - idx);
      sum2 += value;
    }
    
    double stepCDFinGeneSet2 = sum1 / sum2;
    double stepCDFoutGeneSet2 = (n * (n + 1) / 2 - sum(n - idxs + 1)) / (n - k);
    double walkStat = stepCDFinGeneSet2 - stepCDFoutGeneSet2;
    
    return walkStat;
  }
")

Rcpp_fasterRndWalk <- function() {
  es <- bplapply(as.list(1:n), function(j) {
    geneRanking <- order(R[, j], decreasing=TRUE)
    es_sample <- lapply(geneSets, fasterRndWalk, geneRanking, j, Ra)
    
    unlist(es_sample)
  }, BPPARAM=SerialParam(progressbar=TRUE))
  es <- do.call("cbind", es)
  return(es)
}

Benchmarking the two implementations.

bench::mark(
  R_fastRndWalk(),
  Rcpp_fasterRndWalk(),
  time_unit="s",
  iterations=10,
  check=FALSE)

# A tibble: 2 × 13
  expression             min median `itr/sec` mem_alloc `gc/sec` n_itr  n_gc total_time result memory                   time            gc               
  <bch:expr>           <dbl>  <dbl>     <dbl> <bch:byt>    <dbl> <int> <dbl>      <dbl> <list> <list>                   <list>          <list>           
1 R_fastRndWalk()      14.4   14.4     0.0684    3.25GB    0.212    10    31      146.  <NULL> <Rprofmem [128,647 × 3]> <bench_tm [10]> <tibble [10 × 3]>
2 Rcpp_fasterRndWalk()  3.71   3.75    0.267    30.67MB    0.160    10     6       37.4 <NULL> <Rprofmem [22,813 × 3]>  <bench_tm [10]> <tibble [10 × 3]>

The C++ implementation is almost 4 times faster (median 3.75 s vs 14.4 s) and allocates about 100 times less memory (30.67MB vs 3.25GB).

The results are slightly different.

> R_fastRndWalk()[1:10, 1:10]
           [,1]       [,2]       [,3]      [,4]      [,5]       [,6]       [,7]      [,8]      [,9]       [,10]
gs1   1460.4123  786.42436  848.96567 1652.6371 1699.5721  868.81145   99.72889 1147.2828  403.7456  951.345237
gs2   1351.3376 1589.24939 1166.46396 -813.0662 1193.3584 1436.63892 -671.25817 1470.0535 1757.5663 1528.896167
gs3    989.5343 -724.26180  988.38412 1285.1941 1725.1857  755.75058  413.20320 1216.5852 1103.2568  860.702336
gs4   2270.6517  913.43945 1786.67177 1630.7358  832.9891  888.71093 1349.89766 1286.1466 1130.9348  347.736033
gs5    965.6410 1770.05765 -718.08133  631.8339 1105.8057 2099.18587  853.27931 1738.7601 -361.8615 1696.477819
gs6   1241.1693  605.35762 1390.54474  218.5366 1603.1661 1064.22024  738.78739 1321.9661 1595.9738  866.650390
gs7   1821.4801 1105.39881 1805.02746  676.4591  738.2390 1670.38658  800.48911 1655.6888 1616.6367 1087.332425
gs8  -1150.0557 3321.83607 -297.52957 2636.5804 2193.6280 1574.25666 1273.75154  693.2889  918.0266 2620.212948
gs9    966.0466   39.91441  -81.57265 -301.6164 1266.5278  751.08846  846.78107 1121.3221  348.3009   -2.232159
gs10  1701.1213 2497.79930 1937.79704  224.0872 2921.6033  -65.45541 1132.45168 1988.8786  482.1426 1195.779765

> Rcpp_fasterRndWalk()[1:10, 1:10]
           [,1]       [,2]       [,3]      [,4]      [,5]       [,6]      [,7]      [,8]      [,9]       [,10]
gs1   1461.2588  787.11592  849.46221 1653.3682 1700.3948  869.71002  100.5005 1147.6113  404.5719  951.448069
gs2   1351.7445 1590.10813 1167.39739 -812.7176 1193.5546 1436.89530 -670.6217 1470.8496 1758.2667 1529.485934
gs3    990.3381 -723.26908  989.34011 1285.4300 1725.7057  756.63913  413.2240 1217.1778 1103.9727  861.649998
gs4   2271.2896  914.25031 1787.42785 1631.2987  833.0586  889.30922 1350.6158 1286.2566 1131.0374  348.201812
gs5    966.2519 1770.93258 -718.06508  631.9691 1105.8603 2099.70185  853.5787 1739.6593 -361.5559 1696.634394
gs6   1242.0753  605.57559 1391.00445  219.5044 1603.2600 1064.83204  739.1218 1322.4162 1596.2087  867.207822
gs7   1822.1339 1105.67133 1805.62463  677.3429  738.5166 1670.98240  801.3700 1655.7198 1617.6060 1087.722861
gs8  -1149.3512 3321.87337 -297.04575 2637.5285 2194.2728 1574.87523 1273.8755  693.7584  918.8056 2620.856900
gs9    966.5787   40.34541  -80.57416 -301.3876 1267.3939  751.96272  846.8289 1121.3280  348.5236   -1.692659
gs10  1701.4705 2498.41092 1938.07949  224.7069 2922.3907  -64.56492 1132.6352 1989.6011  483.1052 1195.850440
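
A possible explanation for the small differences, offered as a guess that has not been checked by rerunning anything: in every entry shown above, the Rcpp value is larger than the R value by less than 1. That pattern is what one would expect if the stepCDFoutGeneSet2 line in fasterRndWalk is evaluated entirely in integer arithmetic, so that the final division truncates. In addition, idxs is zero-based in the C++ code, so `n - idxs + 1` is one larger per gene than the corresponding R expression. The plain C++ sketch below isolates just that term; the function names are hypothetical and it does not use Rcpp.

```cpp
#include <vector>

// Out-of-set term as written in fasterRndWalk: zero-based positions and
// integer arithmetic throughout, so the final division truncates.
double outTermAsPosted(const std::vector<int>& idxs0, int n) {
    int k = static_cast<int>(idxs0.size());
    int s = 0;
    for (int idx : idxs0) s += n - idx + 1;   // one more per gene than in R
    return (n * (n + 1) / 2 - s) / (n - k);   // int / int, then widened
}

// Same term following the R code: a zero-based idx corresponds to the
// one-based position idx + 1, so n - (idx + 1) + 1 == n - idx, and the
// division is done in double.
double outTermLikeR(const std::vector<int>& idxs0, int n) {
    int k = static_cast<int>(idxs0.size());
    double s = 0.0;
    for (int idx : idxs0) s += n - idx;
    return (0.5 * n * (n + 1.0) - s) / (n - k);
}
```

With n = 10 and zero-based positions {0, 3, 6}, the first function returns 4 while the second returns 34/7 (about 4.857). Because this term is subtracted, the truncated version gives a slightly larger walk statistic. Doing the arithmetic in double would also avoid overflow of `n * (n + 1)` in a 32-bit int, which starts at n = 46341.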

My C++ is rusty (because of Rust) and I know very little C, so I imagine someone else could improve this further or reimplement it in C and avoid any more dependencies. I'm not too proud to admit that I needed ChatGPT to debug a line of code for me here.
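
As a rough illustration of what a version without the Rcpp dependency might look like, here is a minimal sketch in plain C++ using only the standard library. It is an assumption-laden sketch, not GSVA code: the function name and argument layout are hypothetical, geneRanking is assumed to be a permutation of 1..n, and weights is assumed to hold column j of Ra. It follows the R arithmetic in double precision and replaces match() with a position lookup table.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical dependency-free walk statistic for one sample.
// gSetIdx:     1-based row indices of the genes in the set
// geneRanking: 1-based row indices ordered by decreasing rank
// weights:     per-gene weights for this sample (one column of Ra)
double rndWalkPlain(const std::vector<int>& gSetIdx,
                    const std::vector<int>& geneRanking,
                    const std::vector<double>& weights) {
    const std::size_t n = geneRanking.size();
    const std::size_t k = gSetIdx.size();

    // position[g - 1] is the 1-based position of gene g in the ranking.
    std::vector<std::size_t> position(n);
    for (std::size_t pos = 0; pos < n; ++pos)
        position[geneRanking[pos] - 1] = pos + 1;

    double weightedSteps = 0.0, weightSum = 0.0, stepSum = 0.0;
    for (int g : gSetIdx) {
        const double step = static_cast<double>(n - position[g - 1] + 1);
        const double w = weights[g - 1];
        weightedSteps += w * step;
        weightSum += w;
        stepSum += step;
    }

    const double nd = static_cast<double>(n);
    const double inSet = weightedSteps / weightSum;
    const double outSet =
        (nd * (nd + 1.0) / 2.0 - stepSum) / static_cast<double>(n - k);
    return inSet - outSet;
}
```

In this sketch the position table is rebuilt on every call; since it depends only on the sample, it could presumably be built once per sample and reused across all gene sets.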

Some relevant versions.

> R.Version()$version.string
[1] "R version 4.2.1 (2022-06-23)"
> packageVersion("GSVA")
[1] ‘1.46.0’
> packageVersion("Rcpp")
[1] ‘1.0.9’
> packageVersion("BiocParallel")
[1] ‘1.32.5’

Cheers, Bob

rpolicastro, Mar 20 '23 18:03