
Parameters of Temporal Problem to use in atlassing context


Hi,

I am using TemporalProblem to compute couplings between different timepoints for my atlassing project. In particular, I am working with brain organoid data from multiple time points across >25 datasets.

My adata object is shape 1770582 × 36842. We performed integration using scpoli and have the embedding stored in .obsm.

This is my initial code:

import moscot as mc

tp = mc.problems.time.TemporalProblem(adata)
tp = tp.prepare(time_key="binned_ages")
tp = tp.solve(epsilon=1e-3, batch_size=64)  # not sure where the online mode is

it is running now with the output:

INFO     Solving problem BirthDeathProblem[stage='prepared', shape=(228554, 404688)]. 

This has been running for over 2.5 days now, so it is extremely slow. I am wondering if there are parameters I should change? What is a reasonable batch size to use in this context? And how do I access the "online mode"? I previously set batch_size=None and immediately encountered a GPU memory error.

Furthermore, should I be feeding the scpoli integrated embedding instead of .X into the function and, if so, how can I do that?

katelynxli avatar May 01 '23 09:05 katelynxli

Hi @katelynxli ,

thanks a lot for your interest in moscot!

This has been running now for over 2.5 days, so it is extremely slow. I am wondering if there are parameters I should change?

yes, this is too long; it shouldn't take anywhere near that much time

What is a reasonable batch size to use in this context?

I think as much as you can fit on the GPU, but I would start with 20000 and go down if you encounter memory errors

And how do I access the "online mode"?

indeed, by passing batch_size it automatically goes into online mode. We should add a note about this in the docs: https://moscot.readthedocs.io/en/latest/genapi/moscot.problems.time.TemporalProblem.solve.html#moscot.problems.time.TemporalProblem.solve
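
For example, a minimal sketch (20000 is just the starting point suggested above, not a tuned value):

# batch_size switches the solver to online mode: the cost matrix is
# evaluated in chunks of this size instead of being fully materialized
tp = tp.solve(epsilon=1e-3, batch_size=20_000)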

I previously set batch_size=None and immediately encountered a GPU Memory error.

that is expected: without batch_size it did not enter online mode and tried to materialize the full cost matrix on the GPU

Furthermore, should I be feeding the scpoli integrated embedding instead of .X into the function and, if so, how can I do that?

yes, you should probably do that, as by default it computes a PCA from adata.X. I would suggest passing to prepare the key in obsm where you have the scpoli embedding, see https://moscot.readthedocs.io/en/latest/genapi/moscot.problems.time.TemporalProblem.prepare.html#moscot.problems.time.TemporalProblem.prepare
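
A minimal sketch, assuming the embedding is stored under adata.obsm["X_scpoli"]:

# joint_attr selects the obsm key to use as the feature space
# instead of the default PCA computed from adata.X
tp = tp.prepare(time_key="binned_ages", joint_attr="X_scpoli")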

let us know if it works and whether you encounter any other issues

giovp avatar May 02 '23 14:05 giovp

Hi @giovp,

Thank you so much for your reply, I've tried the above suggestions and it helped!

Unfortunately, the highest batch size I could set without running into a GPU memory error was batch_size=2048. With this, tp.solve() has now been running for 24 hours and has not finished yet.

epsilon = 0.05
tau_a = 0.8
tau_b = 0.95

tp = TemporalProblem(adata)  # , solver=SinkhornSolver()
tp.prepare("binned_ages", joint_attr="X_scpoli")
# output: TemporalProblem[(90, 120), (15, 30), (120, 150), (150, 450), (60, 90), (30, 60)]

tp.solve(epsilon=epsilon, tau_a=tau_a, tau_b=tau_b, batch_size=2048, scale_cost='mean')
# output: INFO Solving problem BirthDeathProblem[stage='prepared', shape=(228554, 404688)].

# check which timepoint pairs converged
converged = {}
for key, val in tp.solutions.items():
    converged[key] = val.converged
converged

# _get_transport_maps: helper (defined elsewhere in our pipeline) that
# extracts the transport maps per timepoint pair
transport_maps = _get_transport_maps(
    time_points=np.sort(adata.obs['binned_ages'].unique()),
    problem=tp,
    adata=adata,
)

I am wondering if you could check my code for anything I am doing wrong, or whether this run time is expected given the ~2 million cells (the scpoli embedding has only 10 features). Additionally, do you recommend tuning the hyperparameters epsilon, tau_a, and tau_b, and if so, how? I am running moscot on a Tesla T4 GPU with 15 GB of GPU memory.

Thanks! Katelyn

katelynxli avatar May 05 '23 10:05 katelynxli

Hi @katelynxli ,

I think the reason for this is the threshold parameter, which we have found hard to set in the unbalanced case. I would suggest two approaches (a combined sketch follows the list):

  1. solve in a balanced manner first, i.e. use tau_a=tau_b=1, and see whether you get faster convergence (and if so, how much faster).
  2. increase the threshold parameter, e.g. from 0.001 to 0.005, or even 0.01 to try.
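
A minimal sketch combining both suggestions, reusing the epsilon, batch_size, and scale_cost from your snippet (threshold is the convergence tolerance of the underlying Sinkhorn solver):

# balanced problem (tau_a = tau_b = 1.0) with a looser convergence threshold
tp.solve(epsilon=0.05, tau_a=1.0, tau_b=1.0, threshold=1e-2,
         batch_size=2048, scale_cost='mean')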

Moreover, I would recommend using prior information from your proliferation and apoptosis rates; see the sketch below. Please follow this example or, for more context, this tutorial.
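
A minimal sketch of incorporating such prior information, assuming human organoid data (score_genes_for_marginals computes proliferation/apoptosis scores from marker gene sets, which are then used as prior growth rates; it has to be called before prepare):

# score cells with proliferation/apoptosis marker gene sets; the scores
# inform the marginals of each birth-death subproblem
tp = tp.score_genes_for_marginals(
    gene_set_proliferation="human",
    gene_set_apoptosis="human",
)
tp = tp.prepare(time_key="binned_ages", joint_attr="X_scpoli")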

MUCDK avatar May 11 '23 10:05 MUCDK

Thank you so much! I increased the batch_size to 8192 and changed the threshold to 0.01. It is still taking quite some time, and I think it would go much faster if I could increase the batch_size further; however, I run out of GPU memory if I go any higher than 8192.

The GPU I am working on has 40 GB of memory. However, I have noticed that just importing TemporalProblem from moscot.problems.time._lineage already consumes 32 GB of GPU memory. Is this something that occurs in your hands as well? @giovp @MUCDK

katelynxli avatar May 25 '23 14:05 katelynxli

Thu May 25 16:50:05 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13    Driver Version: 525.60.13    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-PCI...  Off  | 00000000:27:00.0 Off |                    0 |
| N/A   70C    P0   255W / 250W |  31121MiB / 40960MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-PCI...  Off  | 00000000:C3:00.0 Off |                    0 |
| N/A   31C    P0    35W / 250W |    419MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    137532      C   ...s/trajectories/bin/python    31118MiB |
|    1   N/A  N/A    137532      C   ...s/trajectories/bin/python      416MiB |
+-----------------------------------------------------------------------------+

katelynxli avatar May 25 '23 14:05 katelynxli

jax by default pre-allocates 90% of the GPU memory; this was changed in https://github.com/ott-jax/ott/pull/353 and will be available in moscot once I make a new release (see the note at the end of this comment for a workaround), but ultimately it's not what causes your OOM. It just seems your data is too large, so here are some suggestions for how to speed it up:

  1. using a Gaussian initializer
  2. using momentum/Anderson acceleration (I recommend reading this when tuning these parameters)
  3. using a bit of unbalancedness in both source and target (paper explaining this)
from ott.solvers.linear import acceleration

import moscot as mc

tp = mc.problems.time.TemporalProblem(adata)
tp = tp.prepare(time_key="binned_ages")
tp = tp.solve(
    initializer="gaussian",  # 1. Gaussian initializer
    momentum=acceleration.Momentum(...),  # 2. momentum acceleration
    anderson=acceleration.AndersonAcceleration(...),  # 2. Anderson acceleration
    recenter_potentials=True,  # 3. recommended in the unbalanced case
    tau_a=0.999,  # 3. slight unbalancedness on the source marginal
    tau_b=0.999,  # 3. slight unbalancedness on the target marginal
    epsilon=1e-2,
    batch_size=8192,
)

In principle, you can combine all 3 approaches, but I would start one-by-one and see whether there's an improvement in runtime (personally, I would start with momentum, then unbalancedness, then the initializer). Please also take a look at this tutorial, which shows how to track the progress. See also https://github.com/theislab/moscot/issues/544 for how to pass some of the arguments.
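
Regarding the 32 GB consumed on import: until the new release is out, you can disable jax's pre-allocation yourself. A minimal sketch (XLA_PYTHON_CLIENT_PREALLOCATE is a standard jax environment variable, not moscot-specific, and must be set before jax is first imported):

import os

# must run before jax (and hence moscot/ott) is imported
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import moscot as mc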

Alternatively, you can try passing rank=500 to run a low-rank solver (the above-mentioned suggestions are not applicable there), which has better time complexity (linear instead of quadratic in the number of cells).
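
A minimal sketch (rank=500 is the value suggested above; smaller ranks are faster but give a coarser approximation):

# low-rank Sinkhorn: the coupling is factored at the given rank,
# bringing time complexity down from quadratic to linear
tp = tp.solve(epsilon=1e-2, rank=500)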

michalk8 avatar May 25 '23 17:05 michalk8

Thank you! This makes sense, and I'll try this and report back.

katelynxli avatar May 25 '23 17:05 katelynxli