GSVA
GSVA copied to clipboard
Speed up ssGSEA and reduce memory by moving .fastRndWalk to C++
Hello,
For ssGSEA on scRNA-seq data, it appears the code runs the .fastRndWalk function n_cells * n_gene_sets times. I was curious whether moving this function to C++ could speed up this operation (and potentially make it more memory efficient), so I roughly reimplemented the function using Rcpp with a few minor changes.
Hijacking your vignette code to make some example data.
library("Rcpp")
library("GSVA")
library("BiocParallel")
## Simulate a genes x samples expression matrix and a list of random gene
## sets. The seed is fixed so the example is reproducible.
set.seed(100)
p <- 20000 ## number of genes
n <- 100 ## number of samples (cells)
## simulate expression values from a standard Gaussian distribution
X <- matrix(rnorm(p * n), nrow = p,
            dimnames = list(paste0("g", seq_len(p)), paste0("s", seq_len(n))))
## Store as a column-compressed sparse matrix. The values are dense Gaussian
## draws, so this only changes the storage class used downstream.
## NOTE(review): the "CsparseMatrix" class is presumably provided by the
## Matrix package attached through GSVA -- confirm it is on the search path.
X <- as(X, "CsparseMatrix")
## sample gene set sizes (100 sets, each with between 10 and 100 genes)
gs <- as.list(sample(10:100, size = 100, replace = TRUE))
## sample gene sets: for each size, draw that many distinct gene names
gs <- lapply(gs, function(n, p)
  paste0("g", sample(seq_len(p), size = n, replace = FALSE)), p)
## seq_along() instead of 1:length() so an empty list cannot yield c(1, 0)
names(gs) <- paste0("gs", seq_along(gs))
Preparing the data to run the old and new functions.
## Prepare the inputs shared by the R and Rcpp drivers below.
## Both helpers called here are unexported GSVA internals (accessed with :::),
## so they may change between GSVA versions.
X <- GSVA:::.filterFeatures(X, "ssgsea")
## NOTE(review): geneSets is presumably a list of integer row indices into X,
## since it is later matched against the output of order() -- confirm.
geneSets <- GSVA:::.mapGeneSetsToFeatures(gs, rownames(X))
## Recompute n after filtering; the driver functions iterate over 1..n columns.
n <- ncol(X)
## Per-column ranks of the expression values, transposed.
## NOTE(review): assumes t() restores a genes x samples layout -- confirm
## against the colRanks() return shape.
R <- t(sparseMatrixStats::colRanks(X, ties.method = "average"))
## Coerce the ranks to integer storage (fractional "average" ranks from ties
## would be truncated by this conversion).
mode(R) <- "integer"
## Rank-based weights used by the random walk.
## NOTE(review): the exponent 0.25 presumably corresponds to the ssGSEA tau
## parameter -- confirm.
Ra <- abs(R)^0.25
The R implementation of .fastRndWalk.
## Random-walk enrichment statistic for one gene set in one sample.
##
## gSetIdx      indices of the genes belonging to the gene set
## geneRanking  gene indices ordered by decreasing rank for sample j
## j            column of Ra to read the weights from
## Ra           matrix of per-gene weights (genes x samples)
##
## Returns a single number: the weighted in-set step CDF term minus the
## out-of-set step CDF term.
.fastRndWalk <- function(gSetIdx, geneRanking, j, Ra) {
  nGenes <- length(geneRanking)
  setSize <- length(gSetIdx)

  ## positions of the gene-set members within the ranking, ascending
  hits <- sort.int(match(gSetIdx, geneRanking))

  ## a hit at position i contributes (nGenes - i + 1) steps
  stepsFromEnd <- nGenes - hits + 1
  hitWeights <- Ra[geneRanking[hits], j]

  inSetCDF <- sum(hitWeights * stepsFromEnd) / sum(hitWeights)
  outSetCDF <- (nGenes * (nGenes + 1) / 2 - sum(stepsFromEnd)) /
    (nGenes - setSize)

  inSetCDF - outSetCDF
}
## Compute the gene-set x sample score matrix with the pure-R .fastRndWalk.
## Reads the globals n, R, Ra and geneSets prepared earlier in the script.
## Returns a matrix with one row per gene set and one column per sample.
R_fastRndWalk <- function() {
  ## seq_len(n) rather than 1:n so that n == 0 yields no iterations
  es <- bplapply(as.list(seq_len(n)), function(j) {
    ## gene indices of sample j from highest to lowest rank
    geneRanking <- order(R[, j], decreasing = TRUE)
    ## vapply enforces exactly one numeric score per gene set and keeps
    ## the gene-set names
    vapply(geneSets, .fastRndWalk, numeric(1),
           geneRanking = geneRanking, j = j, Ra = Ra)
  }, BPPARAM = SerialParam(progressbar = TRUE))
  ## bind the per-sample score vectors into columns
  do.call("cbind", es)
}
Here's the Rcpp implementation, fasterRndWalk.
## Compile the C++ port of .fastRndWalk and export it to R as fasterRndWalk().
## Fixes relative to the first draft:
##   * the out-of-set term is computed in double precision; the draft
##     evaluated it with int arithmetic, which truncated the division by
##     (n - k) and could overflow n * (n + 1) for large n;
##   * the number of steps per hit is n - position + 1 with 1-based positions,
##     matching the R code; the draft added 1 to already 0-based indices;
##   * gene-set members that are absent from the ranking are skipped, as
##     sort.int() drops NA matches in the R code.
sourceCpp(code="
#include <Rcpp.h>
using namespace Rcpp;

// Random-walk enrichment statistic for one gene set in one sample.
//   gSetIdx      1-based indices of the genes in the gene set
//   geneRanking  1-based gene indices ordered by decreasing rank
//   j            1-based column of Ra holding the weights of the sample
//   Ra           genes x samples matrix of weights
// Returns the in-set step CDF term minus the out-of-set step CDF term.
// [[Rcpp::export]]
double fasterRndWalk(IntegerVector gSetIdx, IntegerVector geneRanking, int j, NumericMatrix Ra) {
  // Hold the sizes as doubles so none of the arithmetic below is carried
  // out with truncating, overflow-prone int operations.
  const double n = geneRanking.size();
  const double k = gSetIdx.size();

  // 1-based positions of the gene-set members within the ranking.
  IntegerVector pos = match(gSetIdx, geneRanking);

  double weightedSteps = 0.0;  // sum of weight * steps over the hits
  double weightTotal = 0.0;    // sum of weights over the hits
  double stepTotal = 0.0;      // sum of steps over the hits

  for (R_xlen_t i = 0; i < pos.size(); ++i) {
    // Skip members that do not occur in the ranking.
    if (IntegerVector::is_na(pos[i])) continue;
    const double step = n - pos[i] + 1.0;
    const double weight = Ra(geneRanking[pos[i] - 1] - 1, j - 1);
    weightedSteps += weight * step;
    weightTotal += weight;
    stepTotal += step;
  }

  const double stepCDFinGeneSet2 = weightedSteps / weightTotal;
  const double stepCDFoutGeneSet2 = (n * (n + 1.0) / 2.0 - stepTotal) / (n - k);
  return stepCDFinGeneSet2 - stepCDFoutGeneSet2;
}
")
## Compute the gene-set x sample score matrix with the compiled fasterRndWalk.
## Reads the globals n, R, Ra and geneSets prepared earlier in the script.
## Returns a matrix with one row per gene set and one column per sample.
Rcpp_fasterRndWalk <- function() {
  ## seq_len(n) rather than 1:n so that n == 0 yields no iterations
  es <- bplapply(as.list(seq_len(n)), function(j) {
    ## gene indices of sample j from highest to lowest rank
    geneRanking <- order(R[, j], decreasing = TRUE)
    ## vapply enforces exactly one numeric score per gene set and keeps
    ## the gene-set names
    vapply(geneSets, fasterRndWalk, numeric(1),
           geneRanking = geneRanking, j = j, Ra = Ra)
  }, BPPARAM = SerialParam(progressbar = TRUE))
  ## bind the per-sample score vectors into columns
  do.call("cbind", es)
}
Benchmarking the two implementations.
## Time both drivers: 10 iterations each, with timings reported in seconds.
## check=FALSE turns off bench::mark's comparison of the results returned by
## the two expressions, so this benchmark does not verify that they agree.
bench::mark(
R_fastRndWalk(),
Rcpp_fasterRndWalk(),
time_unit="s",
iterations=10,
check=FALSE)
# A tibble: 2 × 13
expression min median `itr/sec` mem_alloc `gc/sec` n_itr n_gc total_time result memory time gc
<bch:expr> <dbl> <dbl> <dbl> <bch:byt> <dbl> <int> <dbl> <dbl> <list> <list> <list> <list>
1 R_fastRndWalk() 14.4 14.4 0.0684 3.25GB 0.212 10 31 146. <NULL> <Rprofmem [128,647 × 3]> <bench_tm [10]> <tibble [10 × 3]>
2 Rcpp_fasterRndWalk() 3.71 3.75 0.267 30.67MB 0.160 10 6 37.4 <NULL> <Rprofmem [22,813 × 3]> <bench_tm [10]> <tibble [10 × 3]>
The C++ implementation is almost 4 times faster and uses about 100 times less memory.
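The results are slightly different: in the output below, the Rcpp scores are consistently higher than the R scores, by between 0 and about 1. A one-sided shift like this points to a discrepancy in how the stepCDFoutGeneSet2 term is computed in the C++ code (integer arithmetic and index base) rather than to floating-point noise.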
> R_fastRndWalk()[1:10, 1:10]
[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
gs1 1460.4123 786.42436 848.96567 1652.6371 1699.5721 868.81145 99.72889 1147.2828 403.7456 951.345237
gs2 1351.3376 1589.24939 1166.46396 -813.0662 1193.3584 1436.63892 -671.25817 1470.0535 1757.5663 1528.896167
gs3 989.5343 -724.26180 988.38412 1285.1941 1725.1857 755.75058 413.20320 1216.5852 1103.2568 860.702336
gs4 2270.6517 913.43945 1786.67177 1630.7358 832.9891 888.71093 1349.89766 1286.1466 1130.9348 347.736033
gs5 965.6410 1770.05765 -718.08133 631.8339 1105.8057 2099.18587 853.27931 1738.7601 -361.8615 1696.477819
gs6 1241.1693 605.35762 1390.54474 218.5366 1603.1661 1064.22024 738.78739 1321.9661 1595.9738 866.650390
gs7 1821.4801 1105.39881 1805.02746 676.4591 738.2390 1670.38658 800.48911 1655.6888 1616.6367 1087.332425
gs8 -1150.0557 3321.83607 -297.52957 2636.5804 2193.6280 1574.25666 1273.75154 693.2889 918.0266 2620.212948
gs9 966.0466 39.91441 -81.57265 -301.6164 1266.5278 751.08846 846.78107 1121.3221 348.3009 -2.232159
gs10 1701.1213 2497.79930 1937.79704 224.0872 2921.6033 -65.45541 1132.45168 1988.8786 482.1426 1195.779765
> Rcpp_fasterRndWalk()[1:10, 1:10]
[,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9] [,10]
gs1 1461.2588 787.11592 849.46221 1653.3682 1700.3948 869.71002 100.5005 1147.6113 404.5719 951.448069
gs2 1351.7445 1590.10813 1167.39739 -812.7176 1193.5546 1436.89530 -670.6217 1470.8496 1758.2667 1529.485934
gs3 990.3381 -723.26908 989.34011 1285.4300 1725.7057 756.63913 413.2240 1217.1778 1103.9727 861.649998
gs4 2271.2896 914.25031 1787.42785 1631.2987 833.0586 889.30922 1350.6158 1286.2566 1131.0374 348.201812
gs5 966.2519 1770.93258 -718.06508 631.9691 1105.8603 2099.70185 853.5787 1739.6593 -361.5559 1696.634394
gs6 1242.0753 605.57559 1391.00445 219.5044 1603.2600 1064.83204 739.1218 1322.4162 1596.2087 867.207822
gs7 1822.1339 1105.67133 1805.62463 677.3429 738.5166 1670.98240 801.3700 1655.7198 1617.6060 1087.722861
gs8 -1149.3512 3321.87337 -297.04575 2637.5285 2194.2728 1574.87523 1273.8755 693.7584 918.8056 2620.856900
gs9 966.5787 40.34541 -80.57416 -301.3876 1267.3939 751.96272 846.8289 1121.3280 348.5236 -1.692659
gs10 1701.4705 2498.41092 1938.07949 224.7069 2922.3907 -64.56492 1132.6352 1989.6011 483.1052 1195.850440
My C++ is rusty (because of Rust) and I know very little C, so I imagine someone else could improve this further or reimplement it in C and avoid any more dependencies. I'm not too proud to admit that I needed ChatGPT to debug a line of code for me here.
Some relevant versions.
> R.Version()$version.string
[1] "R version 4.2.1 (2022-06-23)"
> packageVersion("GSVA")
[1] ‘1.46.0’
> packageVersion("Rcpp")
[1] ‘1.0.9’
> packageVersion("BiocParallel")
[1] ‘1.32.5’
Cheers, Bob