change `batch_size` in `sparsify`
Needed to keep the solver fast (high batch size) while preventing OOM in `sparsify` (which needs a low batch size); see https://github.com/theislab/cellrank/issues/1146#issuecomment-1856213099
hey @giovp @MUCDK, I ran a quick benchmark to capture the peak memory.

From what I checked, the largest allocations in `solve` have shape `(batch_size_point_cloud, d)`, `(n, 1)`, or `(m, 1)`, and `apply_lse_kernel` contributes to these allocations. Meanwhile, `apply` runs `apply_lse_kernel` under `vmap`, so it is executed `batch_size_sparse` times and the results are stacked: each `apply_lse_kernel` call produces an `(n, 1)` array, and `apply` returns an `(n, batch_size_sparse)` array.
If `m = n`, these should be the memory complexities in theory:

- `solve` memory complexity is `O(max(n, batch_size_point_cloud * d))`
- `apply` memory complexity is `O(max(n * batch_size_sparse, batch_size_point_cloud * d))`

However, as in the linked issue, `apply` tries to allocate `(batch_size, d, m)`. I think this is due to the `vmap` usage in `apply`: when I call `jax.make_jaxpr(solver)(ot_prob)` I get at most 2-D array shapes, whereas `jax.make_jaxpr(ot.apply)(jnp.eye(19, n))` gives 3-D array shapes. (You can also run these after solving in this notebook: https://github.com/ott-jax/ott/blob/main/docs/tutorials/point_clouds.ipynb)
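For concreteness, a minimal sketch of that inspection (assuming the ott-jax API from the linked tutorial; the sizes and `batch_size=64` here are small and purely illustrative, not the benchmark values):

```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

n, m, d = 200, 250, 30
rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (n, d))
y = jax.random.normal(rng_y, (m, d))

# online mode: the cost matrix is never materialized, only batch-sized blocks
geom = pointcloud.PointCloud(x, y, batch_size=64)
ot_prob = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn(max_iterations=2)

# jaxpr of solve: intermediate shapes stay at most 2-D
print(jax.make_jaxpr(solver)(ot_prob))

# jaxpr of apply on 19 vectors: 3-D intermediates show up,
# coming from the vmap over apply_lse_kernel
ot = solver(ot_prob)
print(jax.make_jaxpr(ot.apply)(jnp.eye(19, n)))
```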
For the benchmark setup, I ran with `n, m, d = 1400, 1700, 400` and `max_iterations=2` for `solve`:
```
[33.33%] ··· benchmarks.PointCloud.peakmem_apply1                           ok
[33.33%] ··· =============== ======= =======
             --                  batch_size_sp
             --------------- ---------------
              batch_size_pc    400     600
             =============== ======= =======
                   400        1.19G   1.52G
                   120        650M    799M
             =============== ======= =======
[66.67%] ··· benchmarks.PointCloud.peakmem_apply2                           ok
[66.67%] ··· =============== ======= =======
             --                  batch_size_sp
             --------------- ---------------
              batch_size_pc    400     600
             =============== ======= =======
                   400        1.99G   2.24G
                   120        1.03G   1.38G
             =============== ======= =======
[100.00%] ··· benchmarks.PointCloud.peakmem_solve                           ok
[100.00%] ··· =============== ====== ======
              --                 batch_size_sp
              --------------- -------------
               batch_size_pc   400    600
              =============== ====== ======
                    400        401M   402M
                    120        384M   394M
              =============== ====== ======
```
Here is the code for the benchmark (I ran it with `asv run --quick --python=same`):
https://gist.github.com/selmanozleyen/70d3ed29aa7841bcaa41f18165f64ab5
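The gist has the full code; roughly, an asv benchmark of this shape (a sketch under assumptions, not the gist verbatim; the distinction between `apply1` and `apply2`, and other setup details, are only in the gist):

```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn


class PointCloud:
    # first parameter: batch_size in ott's PointCloud (table rows),
    # second: number of vectors applied at once, i.e. batch_size_sp (table columns)
    params = ([400, 120], [400, 600])
    param_names = ["batch_size_pc", "batch_size_sp"]

    def setup(self, batch_size_pc, batch_size_sp):
        n, m, d = 1400, 1700, 400
        rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
        x = jax.random.normal(rng_x, (n, d))
        y = jax.random.normal(rng_y, (m, d))
        geom = pointcloud.PointCloud(x, y, batch_size=batch_size_pc)
        self.prob = linear_problem.LinearProblem(geom)
        self.solver = sinkhorn.Sinkhorn(max_iterations=2)
        self.out = self.solver(self.prob)
        self.n = n

    def peakmem_solve(self, batch_size_pc, batch_size_sp):
        self.solver(self.prob).f.block_until_ready()

    def peakmem_apply(self, batch_size_pc, batch_size_sp):
        # apply the (implicit) transport matrix to batch_size_sp vectors at once
        self.out.apply(jnp.eye(batch_size_sp, self.n)).block_until_ready()
```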
Great, thanks. What is `batch_size_sparse`?

So it seems like `solve` and `apply` require the same memory? Could you please check by setting `batch_size_sparse` to 1?
> So it seems like `solve` and `apply` require the same memory? Could you please check by setting `batch_size_sparse` to 1?
No, `solve` doesn't exceed 500 MB. Here are the results with `batch_size_sp=1`; it's similar to `solve`:
```
[33.33%] ··· benchmarks.PointCloud.peakmem_apply1                           ok
[33.33%] ··· =============== ======= ======= ======
             --                      batch_size_sp
             --------------- ----------------------
              batch_size_pc    400     600     1
             =============== ======= ======= ======
                   400        1.31G   1.77G   375M
                   120        656M    795M    377M
             =============== ======= ======= ======
[66.67%] ··· benchmarks.PointCloud.peakmem_apply2                           ok
[66.67%] ··· =============== ======= ======= ======
             --                      batch_size_sp
             --------------- ----------------------
              batch_size_pc    400     600     1
             =============== ======= ======= ======
                   400        1.64G   2.48G   393M
                   120        1.05G   1.37G   379M
             =============== ======= ======= ======
[100.00%] ··· benchmarks.PointCloud.peakmem_solve                           ok
[100.00%] ··· =============== ====== ====== ======
              --                     batch_size_sp
              --------------- --------------------
               batch_size_pc   400    600    1
              =============== ====== ====== ======
                    400        403M   403M   390M
                    120        393M   396M   394M
              =============== ====== ====== ======
```
Yeah, but it seems like `apply` requires just `batch_size_sparse` times more memory, which means that `apply` and `solve` require equally much memory when we `apply` to only one vector. Hence, `vmap` is compatible with the `batch_size` argument, right?

All in all, this means that we require `batch_size_sparse` times more memory in `sparsify`, right?
> Yeah, but it seems like `apply` requires just `batch_size_sparse` times more memory, which means that `apply` and `solve` require equally much memory when we `apply` to only one vector. Hence, `vmap` is compatible with the `batch_size` argument, right? All in all, this means that we require `batch_size_sparse` times more memory in `sparsify`, right?
Yes, that is correct.
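(As a toy illustration of why this holds, in plain JAX rather than ott internals: vmapping a kernel over `k` vectors materializes the per-call intermediates `k` times, so peak memory grows roughly linearly in `k`. The dense `(n, m)` block below is an illustrative stand-in for what `apply_lse_kernel` computes in batches.)

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def lse_kernel_like(v, x, y):
    # stand-in for apply_lse_kernel: here the (n, m) cost block is built
    # densely for clarity (ott builds it in batch-sized chunks)
    cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # (n, m)
    return logsumexp(-cost + v[None, :], axis=1)                   # (n,)


n, m, d, k = 500, 500, 10, 8
x = jax.random.normal(jax.random.PRNGKey(0), (n, d))
y = jax.random.normal(jax.random.PRNGKey(1), (m, d))
vs = jax.random.normal(jax.random.PRNGKey(2), (k, m))

single = lse_kernel_like(vs[0], x, y)                       # peak ~ n * m
batched = jax.vmap(lambda v: lse_kernel_like(v, x, y))(vs)  # peak ~ k * n * m
print(single.shape, batched.shape)                          # (500,) (8, 500)
```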
Okay, now the question is whether it's faster to decrease the batch size in ott-jax (i.e. in `PointCloud`) and hence increase the batch size in `sparsify`.

Any chance you could benchmark this?
yep, here are the results
```
[50.00%] ··· Running (benchmarks.PointCloud.time_apply1--).
[75.00%] ··· benchmarks.PointCloud.peakmem_apply1                           ok
[75.00%] ··· =============== ======= ======= =======
             --                      batch_size_sp
             --------------- -----------------------
              batch_size_pc    1200    600     120
             =============== ======= ======= =======
                   1200        5.5G    5.22G   1.43G
                   600         5.44G   5.21G   1.4G
                   120         1.24G   792M    502M
             =============== ======= ======= =======
[100.00%] ··· benchmarks.PointCloud.time_apply1                             ok
[100.00%] ··· =============== ============ =========== ===========
              --                           batch_size_sp
              --------------- ------------------------------------
               batch_size_pc      1200         600         120
              =============== ============ =========== ===========
                    1200       8.41±0.01s   2.30±0.1s   490±5ms
                    600        5.24±0s      2.12±0s     491±1ms
                    120        3.24±0.02s   1.57±0s     422±0.2ms
              =============== ============ =========== ===========
```
Hence, it's faster if we have a large batch size in `point_cloud` and a small batch size in `sparsify`, is this correct? Thus, there doesn't seem to be a need to change anything in the code, as we would not want to decrease the `batch_size` in `PointCloud`, right?

Thus, if we use `batch_size=1` in `sparsify`, this would prevent running OOM and is still the fastest option, even if not great as it takes forever.
Maybe one last thing @selmanozleyen: did you convert the output to a `csr_matrix` within the `for` loop to simulate what we are doing in `sparsify`? As this might require a lot of overhead.
> Hence, it's faster if we have a large batch size in `point_cloud` and a small batch size in `sparsify`, is this correct? Thus, there doesn't seem to be a need to change anything in the code, as we would not want to decrease the `batch_size` in `PointCloud`, right?
Yes, for these values of `m`, `n`, `d`.
> Thus, if we use `batch_size=1` in `sparsify`, this would prevent running OOM and is still the fastest option, even if not great as it takes forever.
For the whole `sparsify` method, I am not sure if `batch_size=1` would be the best, but it seems like the smaller the better. Also, is it really that slow? Is there a chance that in those cases the data was on the GPU? `sparsify` seems to implicitly copy from GPU to CPU on each iteration of the `for` loop, which is usually very slow. Unless I got something wrong, a warning might be a good idea if the data is on the GPU.
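For intuition, the batched loop under discussion looks roughly like the sketch below (an assumption about the overall shape, not moscot's actual implementation; `sparsify_by_batches` and `threshold` are hypothetical names, and a plain threshold stands in for the real `min_row` rule). The point is the device-to-host copy that happens once per batch:

```python
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp


def sparsify_by_batches(ot, n, batch_size, threshold):
    """Reconstruct and threshold the transport plan batch_size rows at a time."""
    blocks = []
    for start in range(0, n, batch_size):
        stop = min(start + batch_size, n)
        # apply the (implicit) transport matrix to a slab of basis vectors
        slab = ot.apply(jnp.eye(stop - start, n, k=start))
        slab = np.asarray(slab)  # device -> host copy, once per batch
        slab[slab < threshold] = 0.0
        blocks.append(sp.csr_matrix(slab))
    return sp.vstack(blocks, format="csr")
```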
> Maybe one last thing @selmanozleyen: did you convert the output to a `csr_matrix` within the `for` loop to simulate what we are doing in `sparsify`? As this might require a lot of overhead.
No, this was just for `apply`.
Btw, I tried the comparison you suggested. The difference is large: the run that is moved to the CPU at the beginning takes so long that it times out at 6 minutes, while the other finishes in 5 seconds. I think the slowness of the computation on the CPU exceeds the cost of copying from GPU to CPU. I can share more details once our clusters are faster; I did this quickly on Colab.
```python
def time_sparsify_cpu_from_start(self, *args, **kwargs):
    # move the solution to the CPU first, then sparsify there
    for (t1, t2), solution in self.res.solutions.items():
        solution = solution.to(device='cpu')
        solution = solution.sparsify(mode='min_row', batch_size=self.batch_size_sp)

def time_sparsify_cpu_implicit(self, *args, **kwargs):
    # keep the solution on the GPU; sparsify copies each batch to the CPU implicitly
    for (t1, t2), solution in self.res.solutions.items():
        solution = solution.to(device='cuda')
        solution = solution.sparsify(mode='min_row', batch_size=self.batch_size_sp)
```
@giovp @MUCDK
Update: it times out in most cases when the solution is moved to the CPU first. The results don't accumulate in GPU RAM anyway, so I think that when applying `sparsify` the solution should be on the GPU, and if it OOMs one should reduce the `batch_size` of `sparsify` as much as possible. Moving to the CPU first makes this call much costlier; maybe that is why it takes so long.
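In other words, roughly this usage pattern (an illustrative sketch; `res` is the problem object from the benchmark snippet above, and `batch_size=32` is an arbitrary small value):

```python
# Keep the solution on the GPU and use a small sparsify batch size.
for (t1, t2), solution in res.solutions.items():
    # No solution.to(device='cpu') first: the per-batch host copies inside
    # sparsify are cheaper than running the whole computation on the CPU.
    solution = solution.sparsify(mode='min_row', batch_size=32)
```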