moscot
`.impute` consumes too much memory
I'm trying to call `problem.impute()` on a solved (linear) spatial mapping problem of dimensions `n_source=17806` (spatial data) by `n_target=13298` (single-cell data) for `n_genes=2039`. This is just a full-rank Sinkhorn problem with `batch_size=None`.
Under the hood, this evaluates:
predictions = [val.to(device=device).pull(gexp_sc, scale_by_marginals=True) for val in self.solutions.values()]
The `pull` amounts to a matrix multiplication, `prediction = P @ X`, for a transport matrix `P` of shape `17806 x 13298` and a single-cell GEX matrix `X` of shape `13298 x 2039`. Thus, the memory bottleneck should be `P`, which is stored as `float32` and should therefore consume around 903 MiB of memory. However, the call to `impute` fails (see traceback below) because it requests `1.76TiB` of memory. That's because it tries to create an array of shape `f32[2039,17806,13298]`, which is not needed for this operation.
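As a sanity check on the numbers above, here is a minimal sketch that derives both sizes from the shapes alone. The variable names are illustrative and nothing here calls moscot.

```python
# Shapes taken from the report above; float32 uses 4 bytes per element.
n_source, n_target, n_genes = 17806, 13298, 2039
BYTES_F32 = 4

# Transport matrix P of shape (n_source, n_target).
p_bytes = n_source * n_target * BYTES_F32

# The unwanted intermediate of shape (n_genes, n_source, n_target).
blown_up_bytes = n_genes * p_bytes

print(f"P: {p_bytes / 2**20:.2f} MiB")
print(f"intermediate: {blown_up_bytes} bytes = {blown_up_bytes / 2**40:.2f} TiB")
```

The byte count of the intermediate agrees with the `1931211837328 bytes` reported in the traceback, and the size of `P` agrees with Buffer 2.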
Note that passing a batch size does not help much. With `batch_size=500`, for example, this would still request an array of shape `2039 x 500 x 13298`, which still requires over 50GB of memory. Also, a small batch size slows down solving the actual OT problem, where batching would not be necessary from a memory point of view.
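To make the batch-size trade-off concrete, the sketch below computes the size of the batched intermediate and the largest batch size that would fit a given budget. The 16 GiB budget and both helper names are hypothetical, chosen only for illustration.

```python
# Shapes from the report; float32 uses 4 bytes per element.
n_target, n_genes = 13298, 2039
BYTES_F32 = 4


def batched_bytes(batch_size: int) -> int:
    """Size in bytes of an f32[n_genes, batch_size, n_target] intermediate."""
    return n_genes * batch_size * n_target * BYTES_F32


def max_batch_size(budget_bytes: int) -> int:
    """Largest batch size whose intermediate fits into budget_bytes."""
    return budget_bytes // batched_bytes(1)


print(f"batch_size=500: {batched_bytes(500) / 2**30:.1f} GiB")
print(f"largest batch size for 16 GiB: {max_batch_size(16 * 2**30)}")
```

Under these assumptions the batch size would have to drop far below 500 just to fit the intermediate, which is the regime where the Sinkhorn solve itself becomes slow.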
I talked to @michalk8 about this and it's probably a `vmap` that creates an array of the wrong shape. As a stopgap, we could evaluate the pull batch-wise over small sets of genes. This is inefficient, but would work around the issue for now.
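A minimal sketch of the gene-wise batching idea, using plain NumPy arrays as stand-ins for the transport matrix and the expression matrix. The function name `pull_in_gene_chunks` and the chunk size are hypothetical; in moscot, the `P @ ...` step would be a `pull` on a slice of the gene matrix, and that wiring is not shown here.

```python
import numpy as np


def pull_in_gene_chunks(P, X, chunk_size=256):
    """Compute P @ X one block of gene columns at a time.

    Only an (n_source, chunk_size) slice of the result is produced per step,
    so any per-gene intermediate is bounded by chunk_size instead of n_genes.
    """
    n_source, n_genes = P.shape[0], X.shape[1]
    out = np.empty((n_source, n_genes), dtype=np.result_type(P, X))
    for start in range(0, n_genes, chunk_size):
        stop = min(start + chunk_size, n_genes)
        out[:, start:stop] = P @ X[:, start:stop]
    return out


# Small stand-in problem: 60 source cells, 40 target cells, 25 genes.
rng = np.random.default_rng(0)
P = rng.random((60, 40), dtype=np.float32)
X = rng.random((40, 25), dtype=np.float32)

chunked = pull_in_gene_chunks(P, X, chunk_size=8)
print(np.allclose(chunked, P @ X, rtol=1e-4))
```

The chunk size trades peak memory against the number of passes over `P`, so it would need tuning to the available device memory.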
If the transport matrix fits into CPU memory, then the current best way to go about this is to materialize the transport matrix before calling `impute`:
for key, value in lmp.problems.items():
    value.solution.to(device="cpu")
    value.set_solution(np.array(value.solution.transport_matrix), overwrite=True)
That prevents the memory issue.
Traceback:
2024-07-05 10:45:20.572529: W external/tsl/tsl/framework/bfc_allocator.cc:485] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.76TiB (rounded to 1931211837440)requested by op
2024-07-05 10:45:20.572824: W external/tsl/tsl/framework/bfc_allocator.cc:497] *****_______________________________________________________________________________________________
2024-07-05 10:45:20.572951: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2732] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 1931211837328 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 929.12MiB
constant allocation: 0B
maybe_live_out allocation: 1.76TiB
preallocated temp allocation: 0B
total allocation: 1.76TiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 1.76TiB
Operator: op_name="jit(_where)/jit(main)/select_n" source_file="/cluster/project/treutlein/USERS/mlange/github/moscot-fork/src/moscot/backends/ott/output.py" source_line=177
XLA Label: fusion
Shape: f32[2039,17806,13298]
==========================
Buffer 2:
Size: 903.26MiB
Entry Parameter Subshape: f32[17806,13298]
==========================
Buffer 3:
Size: 25.86MiB
Entry Parameter Subshape: pred[2039,1,13298]
==========================
Buffer 4:
Size: 4B
Entry Parameter Subshape: f32[]
==========================
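The buffer shapes in the traceback are consistent with a boolean mask of shape `pred[2039,1,13298]` being broadcast against the `f32[17806,13298]` transport matrix inside a `where`/`select_n`. Whether that is what `output.py` line 177 actually does is an assumption, not something confirmed here. The sketch below only checks the shape arithmetic, without allocating anything.

```python
import numpy as np

# Shapes copied from the traceback; the interpretation is an assumption.
mask_shape = (2039, 1, 13298)  # Buffer 3: pred[2039,1,13298]
p_shape = (17806, 13298)       # Buffer 2: f32[17806,13298]

# NumPy-style broadcasting (which JAX follows) aligns trailing dimensions,
# so the size-1 axis of the mask is stretched to n_source.
result_shape = np.broadcast_shapes(mask_shape, p_shape)
print(result_shape)
```

If this reading is right, the result shape matches the `f32[2039,17806,13298]` of Buffer 1, and applying the mask to the gene matrix before the multiplication would avoid the three-dimensional intermediate.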