Parameters of Temporal Problem to use in atlassing context
Hi,
I am using TemporalProblem to compute couplings between different timepoints for my atlassing project. In particular, I am working with brain organoid data from multiple time points across >25 datasets.
My adata object is shape 1770582 × 36842. We performed integration using scpoli and have the embedding stored in .obsm.
This is my initial code:
```python
tp = mc.problems.time.TemporalProblem(adata)
tp = tp.prepare(time_key="binned_ages")
tp = tp.solve(epsilon=1e-3, batch_size=64)  # not sure where the online mode is
```
it is running now with the output:
```
INFO Solving problem BirthDeathProblem[stage='prepared', shape=(228554, 404688)].
```
This has been running now for over 2.5 days, so it is extremely slow. I am wondering if there are parameters I should change? What is a reasonable batch size to use in this context? And how do I access the "online mode"? I previously set batch_size=None and immediately encountered a GPU Memory error.
Furthermore, should I be feeding the scpoli integrated embedding instead of .X into the function and, if so, how can I do that?
Hi @katelynxli ,
thanks a lot for your interest in moscot!
> This has been running now for over 2.5 days, so it is extremely slow. I am wondering if there are parameters I should change?

yes, this is too long, it should not take that much time
> What is a reasonable batch size to use in this context?

I think as much as you can fit on the GPU, but I would start with 20000 and go down if you encounter memory errors
> And how do I access the "online mode"?

indeed, by passing `batch_size` it automatically goes into online mode. We should add a note in the docs: https://moscot.readthedocs.io/en/latest/genapi/moscot.problems.time.TemporalProblem.solve.html#moscot.problems.time.TemporalProblem.solve
> I previously set batch_size=None and immediately encountered a GPU Memory error.

this is likely because it did not enter online mode
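For a sense of scale, the shapes in your log explain the immediate OOM. This is a rough model in plain Python, assuming float32 cost entries and that online mode only ever materializes `batch_size` rows of the cost matrix at a time (a simplification, not ott's exact memory layout):

```python
# Rough memory arithmetic for the largest coupling in the log,
# shape (228554, 404688). Assumes float32 (4 bytes per entry) and a
# simplified online mode that holds batch_size rows of the cost matrix.
n, m = 228_554, 404_688
gib = 1024 ** 3

dense_cost = n * m * 4 / gib     # full cost matrix, no online mode
chunk_2048 = 2048 * m * 4 / gib  # one chunk at batch_size=2048

print(f"dense cost matrix: ~{dense_cost:.0f} GiB")      # ~345 GiB: no single GPU fits this
print(f"batch_size=2048 chunk: ~{chunk_2048:.1f} GiB")  # ~3.1 GiB: fits on a 15 GB T4
```

So without online mode the dense cost matrix alone is in the hundreds of GiB, while a modest `batch_size` keeps the working set to a few GiB.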
> Furthermore, should I be feeding the scpoli integrated embedding instead of .X into the function and, if so, how can I do that?

yes, you should probably do that, as by default it computes a PCA from `adata.X`. I would suggest passing, in `prepare`, the key in `obsm` where you have the scpoli embedding, see https://moscot.readthedocs.io/en/latest/genapi/moscot.problems.time.TemporalProblem.prepare.html#moscot.problems.time.TemporalProblem.prepare
let us know if it works and if you encounter any other issues
Hi @giovp,
Thank you so much for your reply, I've tried the above suggestions and it helped!
Unfortunately, the highest batch size I could set without running into a GPU memory error was `batch_size=2048`. With this, `tp.solve()` has been running for 24 hours and has not finished yet.
```python
epsilon = 0.05
tau_a = 0.8
tau_b = 0.95

tp = TemporalProblem(adata)  # , solver=SinkhornSolver()
tp.prepare("binned_ages", joint_attr="X_scpoli")
```
```
TemporalProblem[(90, 120), (15, 30), (120, 150), (150, 450), (60, 90), (30, 60)]
```
```python
tp.solve(epsilon=epsilon, tau_a=tau_a, tau_b=tau_b, batch_size=2048, scale_cost='mean')
```
```
INFO Solving problem BirthDeathProblem[stage='prepared', shape=(228554, 404688)].
```
```python
converged = {}
for key, val in tp.solutions.items():
    converged[key] = val.converged
converged

transport_maps = _get_transport_maps(
    time_points=np.sort(adata.obs['binned_ages'].unique()),
    problem=tp,
    adata=adata,
)
```
I am wondering if you can check my code and see if there's anything I am doing wrong, or if this length of run time is expected given the 2 million cells (scpoli embedding has only 10 features). Additionally, do you recommend that I tune the hyperparameters epsilon, tau_a, and tau_b, and if so how? I am running moscot on a Tesla T4 GPU with 15GB of GPU memory.
Thanks! Katelyn
Hi @katelynxli ,
I think the reason for this is the `threshold` parameter, which we experienced to be hard to set in the unbalanced case. I would suggest two approaches:

- solve in a balanced manner and see whether you get faster convergence (if so, how fast), i.e. use `tau_a=tau_b=1`.
- increase the `threshold` parameter, e.g. from `0.001` to `0.005`, or even try `0.01`.
Moreover, I would recommend using prior information from your proliferation and apoptosis rates. Please follow this example or, for more context, this tutorial.
Thank you so much! I increased the `batch_size` to `8192` and changed the `threshold` to `0.01`. It is still taking quite some time, and I think it would go much faster if I could increase the `batch_size`; however, I run out of GPU memory if I go any higher than `8192`.
The GPU I am working on has 40 GB of GPU memory. However, I have noticed that just loading `TemporalProblem` from `moscot.problems.time._lineage` requires 32 GB of GPU memory. Is this something that occurs in your hands as well? @giovp @MUCDK
```
Thu May 25 16:50:05 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.13    Driver Version: 525.60.13    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-PCI...  Off  | 00000000:27:00.0 Off |                    0 |
| N/A   70C    P0   255W / 250W |  31121MiB / 40960MiB |    100%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-PCI...  Off  | 00000000:C3:00.0 Off |                    0 |
| N/A   31C    P0    35W / 250W |    419MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A    137532      C   ...s/trajectories/bin/python    31118MiB |
|    1   N/A  N/A    137532      C   ...s/trajectories/bin/python      416MiB |
+-----------------------------------------------------------------------------+
```
`jax` by default pre-allocates 90% of the GPU memory; this was changed in https://github.com/ott-jax/ott/pull/353 and will be available in moscot once I make a new release, but ultimately it's not what causes your OOM.
It just seems your data is too large; here are some suggestions for how to speed it up:
- using a Gaussian initializer
- using momentum/Anderson acceleration (I recommend reading this when tuning these parameters)
- using a bit of unbalancedness in both source and target (paper explaining this)
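On the unbalancedness point: if I recall OTT's parameterization correctly, `tau = rho / (rho + epsilon)`, where `rho` weights the soft (KL) penalty on the marginal constraints, so `tau` close to 1 means nearly balanced. A quick plain-Python conversion (illustrative only, `tau_to_rho` is my own helper name):

```python
def tau_to_rho(tau, epsilon):
    # OTT convention: tau = rho / (rho + epsilon)  =>  rho = epsilon * tau / (1 - tau)
    # Larger rho = harder marginal constraints = closer to balanced OT.
    return epsilon * tau / (1 - tau)

epsilon = 1e-2
for tau in (0.999, 0.95, 0.8):
    print(f"tau={tau}: rho={tau_to_rho(tau, epsilon):.2f}")
```

This is why `tau_a=tau_b=0.999` below is only "a bit" of unbalancedness, while your earlier `tau_a=0.8` relaxed the marginals much more strongly.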
```python
from ott.solvers.linear import acceleration

tp = mc.problems.time.TemporalProblem(adata)
tp = tp.prepare(time_key="binned_ages")
tp = tp.solve(
    initializer="gaussian",                            # 1.
    momentum=acceleration.Momentum(...),               # 2.
    anderson=acceleration.AndersonAcceleration(...),   # 2.
    recenter_potentials=True,                          # 3.
    tau_a=0.999,                                       # 3.
    tau_b=0.999,                                       # 3.
    epsilon=1e-2,
    batch_size=8192,
)
```
In principle, you can combine all 3 approaches, but I would enable them one by one and see whether there's an improvement in runtime (personally I would start with momentum, then unbalancedness and then the initializer). Please also take a look at this tutorial that shows how to track the progress. See also https://github.com/theislab/moscot/issues/544 for how to pass some of the arguments.
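To build intuition for what the momentum knob does, here is a toy log-domain Sinkhorn in plain NumPy (my own sketch, not moscot/ott code): with momentum weight `w`, each dual potential update is extrapolated as a `w`-weighted combination of the old and the freshly computed potential, and `w > 1` over-relaxes the iteration.

```python
import numpy as np

def _logsumexp(x, axis):
    # numerically stable log-sum-exp along one axis
    m = x.max(axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True)), axis=axis)

def sinkhorn(cost, epsilon=0.1, momentum=1.0, n_iters=300):
    """Toy log-domain Sinkhorn with uniform marginals; momentum > 1 over-relaxes."""
    n, m = cost.shape
    log_a = np.full(n, -np.log(n))  # uniform source marginal (in log space)
    log_b = np.full(m, -np.log(m))  # uniform target marginal (in log space)
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iters):
        g_new = -epsilon * _logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
        g = (1 - momentum) * g + momentum * g_new  # momentum step on the dual
        f_new = -epsilon * _logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        f = (1 - momentum) * f + momentum * f_new
    # recover the transport plan from the potentials
    return np.exp((f[:, None] + g[None, :] - cost) / epsilon
                  + log_a[:, None] + log_b[None, :])

rng = np.random.default_rng(0)
cost = rng.random((30, 40))
plan = sinkhorn(cost, momentum=1.3)  # mild over-relaxation
print(plan.sum(axis=1).max())        # row marginals approach 1/30 at convergence
```

`momentum=1.0` recovers plain Sinkhorn; over-relaxation typically needs fewer iterations to hit a given threshold, which is the runtime win the suggestion above is after.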
Alternatively, you can try passing `rank=500` to run a low-rank solver (the above-mentioned suggestions are not applicable here), which has better time complexity (linear vs. quadratic).
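To make the linear-vs-quadratic claim concrete, here is a back-of-the-envelope count of per-iteration kernel work for the largest coupling in this thread (plain Python; constants and the exact solver internals are ignored):

```python
# Per-iteration work: full-rank Sinkhorn touches the n x m kernel (quadratic),
# while a rank-r solver works with factors of size ~n x r and m x r (linear in n + m).
n, m, r = 228_554, 404_688, 500
full_rank = n * m
low_rank = r * (n + m)
print(f"full-rank: ~{full_rank:.1e} ops/iter")
print(f"rank-500:  ~{low_rank:.1e} ops/iter (~{full_rank // low_rank}x fewer)")
```

Roughly a 300x reduction in per-iteration work at these problem sizes, at the cost of approximating the coupling by a rank-500 factorization.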
Thank you! This makes sense, and I'll try this and report back.