
change `batch_size` in `sparsify`

Open MUCDK opened this issue 1 year ago • 11 comments

We need a high batch size to keep the solver fast, but a low batch size in `sparsify` to prevent OOM; see https://github.com/theislab/cellrank/issues/1146#issuecomment-1856213099

MUCDK avatar Dec 14 '23 20:12 MUCDK

Hey @giovp @MUCDK, I ran some quick benchmarks to capture the peak memory.

From what I checked, the largest arrays allocated in `solve` have shapes `(batch_size_point_cloud, d)`, `(n, 1)`, or `(m, 1)`, and `apply_lse_kernel` contributes these allocations. Meanwhile, `apply` calls `apply_lse_kernel` through `vmap`, so it is run `batch_size_sparse` times and the results are stacked: each `apply_lse_kernel` call yields `(n, 1)`, and `apply` returns `(n, batch_size_sparse)`.

If m=n, these should be the memory complexities in theory:

  • `solve` memory complexity is O(max(n, batch_size_point_cloud*d))
  • `apply` memory complexity is O(max(n*batch_size_sparse, batch_size_point_cloud*d))

However, as in the linked issue, `apply` tries to allocate `(batch_size, d, m)`. I think this is due to the `vmap` usage in `apply`. When I call `jax.make_jaxpr(solver)(ot_prob)` I get at most 2-D array shapes, while `jax.make_jaxpr(ot.apply)(jnp.eye(19, n))` yields 3-D shapes. (You can also run these after solving in this notebook: https://github.com/ott-jax/ott/blob/main/docs/tutorials/point_clouds.ipynb)
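
For reference, a minimal self-contained sketch of this inspection (module paths assume a recent ott-jax layout; the problem sizes here are illustrative, not the benchmark values):

```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

n, m, d = 128, 96, 8
rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (n, d))
y = jax.random.normal(rng_y, (m, d))

# batched/online cost evaluation inside the geometry
geom = pointcloud.PointCloud(x, y, batch_size=32)
ot_prob = linear_problem.LinearProblem(geom)
solver = sinkhorn.Sinkhorn(max_iterations=10)
ot = solver(ot_prob)

# per the observation above: solve's jaxpr contains at most 2-D intermediates ...
print(jax.make_jaxpr(solver)(ot_prob))
# ... while apply on a batch of vectors (vmap over apply_lse_kernel) shows 3-D ones
print(jax.make_jaxpr(ot.apply)(jnp.eye(19, n)))
```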

Benchmark setup:

I ran with `n, m, d = 1400, 1700, 400` and `max_iterations=2` for `solve`.

[33.33%] ··· benchmarks.PointCloud.peakmem_apply1                                ok
[33.33%] ··· =============== ======= =======
             --               batch_size_sp 
             --------------- ---------------
              batch_size_pc    400     600  
             =============== ======= =======
                   400        1.19G   1.52G 
                   120         650M    799M 
             =============== ======= =======

[66.67%] ··· benchmarks.PointCloud.peakmem_apply2                                ok
[66.67%] ··· =============== ======= =======
             --               batch_size_sp 
             --------------- ---------------
              batch_size_pc    400     600  
             =============== ======= =======
                   400        1.99G   2.24G 
                   120        1.03G   1.38G 
             =============== ======= =======

[100.00%] ··· benchmarks.PointCloud.peakmem_solve                                 ok
[100.00%] ··· =============== ====== ======
              --              batch_size_sp
              --------------- -------------
               batch_size_pc   400    600  
              =============== ====== ======
                    400        401M   402M 
                    120        384M   394M 
              =============== ====== ======

Here is the benchmark code (I ran it with `asv run --quick --python=same`): https://gist.github.com/selmanozleyen/70d3ed29aa7841bcaa41f18165f64ab5
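
For reference, a rough reconstruction of what such an asv benchmark could look like; the gist above is authoritative, and the parameter grid, names, and the exact quantity measured by `apply2` are assumptions here:

```python
import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn


class PointCloud:
    # asv sweeps the cross product of these parameter lists
    params = ([400, 120], [400, 600])
    param_names = ["batch_size_pc", "batch_size_sp"]

    def setup(self, batch_size_pc, batch_size_sp):
        n, m, d = 1400, 1700, 400
        rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
        x = jax.random.normal(rng_x, (n, d))
        y = jax.random.normal(rng_y, (m, d))
        geom = pointcloud.PointCloud(x, y, batch_size=batch_size_pc)
        self.prob = linear_problem.LinearProblem(geom)
        self.solver = sinkhorn.Sinkhorn(max_iterations=2)
        self.out = self.solver(self.prob)
        self.n = n

    def peakmem_solve(self, batch_size_pc, batch_size_sp):
        jax.block_until_ready(self.solver(self.prob))

    def peakmem_apply1(self, batch_size_pc, batch_size_sp):
        # push batch_size_sp unit vectors through the transport map at once
        jax.block_until_ready(self.out.apply(jnp.eye(batch_size_sp, self.n)))

    # the gist also contains an apply2 variant and time_* counterparts
```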

selmanozleyen avatar Mar 06 '24 14:03 selmanozleyen

Great, thanks. What is `batch_size_sparse`?

MUCDK avatar Mar 06 '24 15:03 MUCDK

So it seems like `solve` and `apply` require the same amount of memory? Could you please check by setting `batch_size_sparse` to 1?

MUCDK avatar Mar 06 '24 15:03 MUCDK

> So it seems like `solve` and `apply` require the same amount of memory? Could you please check by setting `batch_size_sparse` to 1?

No, `solve` doesn't exceed 500 MB. Here are the results with `batch_size_sp=1`; they are similar to `solve`:

[33.33%] ··· benchmarks.PointCloud.peakmem_apply1                                ok
[33.33%] ··· =============== ======= ======= ======
             --                  batch_size_sp     
             --------------- ----------------------
              batch_size_pc    400     600     1   
             =============== ======= ======= ======
                   400        1.31G   1.77G   375M 
                   120         656M    795M   377M 
             =============== ======= ======= ======

[66.67%] ··· benchmarks.PointCloud.peakmem_apply2                                ok
[66.67%] ··· =============== ======= ======= ======
             --                  batch_size_sp     
             --------------- ----------------------
              batch_size_pc    400     600     1   
             =============== ======= ======= ======
                   400        1.64G   2.48G   393M 
                   120        1.05G   1.37G   379M 
             =============== ======= ======= ======

[100.00%] ··· benchmarks.PointCloud.peakmem_solve                                 ok
[100.00%] ··· =============== ====== ====== ======
              --                 batch_size_sp    
              --------------- --------------------
               batch_size_pc   400    600     1   
              =============== ====== ====== ======
                    400        403M   403M   390M 
                    120        393M   396M   394M 
              =============== ====== ====== ======


selmanozleyen avatar Mar 06 '24 15:03 selmanozleyen

Yeah, but it seems like `apply` requires just `batch_size_sparse` times more memory, which means that `apply` and `solve` require equally much when we apply to only one vector. Hence, `vmap` is compatible with the `batch_size` argument, right?

All in all, this means that we require `batch_size_sparse` times more memory in `sparsify`, right?

MUCDK avatar Mar 06 '24 15:03 MUCDK

> Yeah, but it seems like `apply` requires just `batch_size_sparse` times more memory, which means that `apply` and `solve` require equally much when we apply to only one vector. Hence, `vmap` is compatible with the `batch_size` argument, right?
>
> All in all, this means that we require `batch_size_sparse` times more memory in `sparsify`, right?

Yes, that is correct.
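
To make the scaling concrete, here is a toy version of such a chunked loop (illustrative only, not moscot's actual `sparsify`; the thresholding rule and the orientation of `apply` are simplified assumptions). Each iteration materializes only a `batch_size_sparse`-row dense block, so peak memory grows linearly with the chunk size:

```python
import jax.numpy as jnp
import numpy as np
from scipy import sparse


def sparsify_like(out, n, batch_size_sparse, threshold=1e-8):
    """Extract a thresholded sparse transport matrix in row chunks."""
    blocks = []
    for start in range(0, n, batch_size_sparse):
        stop = min(start + batch_size_sparse, n)
        # rows start..stop of the n x n identity, pushed through the transport map
        chunk = jnp.eye(stop - start, n, k=start)
        block = np.array(out.apply(chunk))  # dense block; device -> host copy
        block[block < threshold] = 0.0
        blocks.append(sparse.csr_matrix(block))
    return sparse.vstack(blocks)
```

With `batch_size_sparse=1` each iteration only materializes a single row, which is consistent with the `batch_size_sp=1` column above being close to `solve`'s footprint.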

selmanozleyen avatar Mar 06 '24 15:03 selmanozleyen

Okay, now the question is whether it's faster to decrease the batch size in ott-jax (i.e., `PointCloud`) and, in turn, increase the batch size in `sparsify`.

Any chance you could benchmark this?

MUCDK avatar Mar 06 '24 15:03 MUCDK

Yep, here are the results:

[50.00%] ··· Running (benchmarks.PointCloud.time_apply1--).
[75.00%] ··· benchmarks.PointCloud.peakmem_apply1                               ok
[75.00%] ··· =============== ======= ======= =======
             --                   batch_size_sp     
             --------------- -----------------------
              batch_size_pc    1200    600     120  
             =============== ======= ======= =======
                   1200        5.5G   5.22G   1.43G 
                   600        5.44G   5.21G    1.4G 
                   120        1.24G    792M    502M 
             =============== ======= ======= =======

[100.00%] ··· benchmarks.PointCloud.time_apply1                                  ok
[100.00%] ··· =============== ============ =========== ===========
              --                         batch_size_sp            
              --------------- ------------------------------------
               batch_size_pc      1200         600         120    
              =============== ============ =========== ===========
                    1200       8.41±0.01s   2.30±0.1s    490±5ms  
                    600         5.24±0s      2.12±0s     491±1ms  
                    120        3.24±0.02s    1.57±0s    422±0.2ms 
              =============== ============ =========== ===========

selmanozleyen avatar Mar 08 '24 10:03 selmanozleyen

Hence, it's faster if we have a large batch size in `PointCloud` and a small batch size in `sparsify`, is this correct? Thus, there doesn't seem to be a need to change anything in the code, as we would not want to decrease the `batch_size` in `PointCloud`, right?

Thus, if we use `batch_size=1` in `sparsify`, this would prevent running OOM and is still the fastest option, even if it's not great as it takes forever.

Maybe one last thing @selmanozleyen: did you convert the output to a `csr_matrix` within the for loop to simulate what we are doing in `sparsify`? As this might require a lot of overhead.

MUCDK avatar Mar 08 '24 11:03 MUCDK

> Hence, it's faster if we have a large batch size in `PointCloud` and a small batch size in `sparsify`, is this correct? Thus, there doesn't seem to be a need to change anything in the code, as we would not want to decrease the `batch_size` in `PointCloud`, right?

Yes, for these values of `m`, `n`, and `d`.

> Thus, if we use `batch_size=1` in `sparsify`, this would prevent running OOM and is still the fastest option, even if it's not great as it takes forever.

For the whole `sparsify` method, I am not sure if `batch_size=1` would be the best, but it seems like the smaller the better. Also, is it really that slow? Is there a chance that in those cases the data was on GPU? `sparsify` seems to implicitly copy from GPU to CPU in each iteration of the for loop, which is usually very slow. Unless I got something wrong, a warning might be a good idea if the data is on GPU.
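
A hypothetical helper for such a warning could look like this (illustrative only, not moscot code; assumes jax >= 0.4 so that `Array.devices()` is available):

```python
import warnings

import jax
import jax.numpy as jnp


def warn_if_on_accelerator(arr: jax.Array) -> None:
    # each sparsify iteration converts a dense block to NumPy/SciPy on the host,
    # so an accelerator-resident array triggers a device -> host copy per iteration
    if any(device.platform != "cpu" for device in arr.devices()):
        warnings.warn(
            "array lives on an accelerator; sparsify will copy each dense block "
            "back to host memory, which can dominate runtime",
            stacklevel=2,
        )


warn_if_on_accelerator(jnp.ones((4, 4)))  # no warning on a CPU-only machine
```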

> Maybe one last thing @selmanozleyen: did you convert the output to a `csr_matrix` within the for loop to simulate what we are doing in `sparsify`? As this might require a lot of overhead.

No, this was just for `apply`.

selmanozleyen avatar Mar 08 '24 12:03 selmanozleyen

Btw, I tried this comparison as you suggested. The difference is large: the variant that moves to CPU at the beginning takes so long that it times out after 6 minutes, while the other is done in 5 seconds. I just think that the slowness of the computation on CPU exceeds the cost of copying from GPU to CPU. I can share more details once our clusters are faster; I did this quickly on Colab.

```python
def time_sparsify_cpu_from_start(self, *args, **kwargs):
    # move the solution to CPU first, then sparsify there
    for (t1, t2), solution in self.res.solutions.items():
        solution = solution.to(device='cpu')
        solution = solution.sparsify(mode='min_row', batch_size=self.batch_size_sp)

def time_sparsify_cpu_implicit(self, *args, **kwargs):
    # keep the solution on GPU; the GPU -> CPU copy happens implicitly inside sparsify
    for (t1, t2), solution in self.res.solutions.items():
        solution = solution.to(device='cuda')
        solution = solution.sparsify(mode='min_row', batch_size=self.batch_size_sp)
```

@giovp @MUCDK Update: it times out in most cases when the solution is moved to CPU beforehand. The results don't accumulate in GPU RAM anyway, so I think the solution should stay on GPU when applying `sparsify`, and if it OOMs one should reduce the `batch_size` of `sparsify` as much as possible. Moving to CPU first makes this call way costlier; maybe that is why it takes so much time.

selmanozleyen avatar Mar 12 '24 14:03 selmanozleyen