
Custom dataloader registry support

Open · ori-kron-wis opened this issue 1 year ago

ori-kron-wis · Aug 07 '24 12:08

Codecov Report

Patch coverage is 81.67614% with 129 lines in your changes missing coverage. Please review. Project coverage is 80.16%. Comparing base (ced87df) to head (67caa96). Report is 70 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| `src/scvi/model/base/_base_model.py` | 49.63% | 69 Missing |
| `src/scvi/dataloaders/_custom_dataloders.py` | 91.59% | 29 Missing |
| `src/scvi/model/base/_archesmixin.py` | 82.22% | 8 Missing |
| `src/scvi/model/base/_training_mixin.py` | 76.47% | 8 Missing |
| `src/scvi/model/_scanvi.py` | 86.66% | 4 Missing |
| `src/scvi/model/base/_rnamixin.py` | 93.33% | 4 Missing |
| `src/scvi/model/base/_vaemixin.py` | 77.77% | 4 Missing |
| `src/scvi/data/_utils.py` | 57.14% | 3 Missing |
Additional details and impacted files
```
@@            Coverage Diff             @@
##             main    #2932      +/-   ##
==========================================
+ Coverage   80.12%   80.16%   +0.04%     
==========================================
  Files         196      197       +1     
  Lines       17570    18156     +586     
==========================================
+ Hits        14078    14555     +477     
- Misses       3492     3601     +109     
```
| Files with missing lines | Coverage Δ |
|---|---|
| `src/scvi/dataloaders/__init__.py` | 100.00% <100.00%> (ø) |
| `src/scvi/dataloaders/_data_splitting.py` | 95.47% <ø> (ø) |
| `src/scvi/model/_scvi.py` | 96.42% <100.00%> (+0.51%) ↑ |
| `src/scvi/model/base/_save_load.py` | 83.49% <100.00%> (+1.38%) ↑ |
| `src/scvi/train/_trainingplans.py` | 85.73% <100.00%> (+0.41%) ↑ |
| `src/scvi/data/_utils.py` | 85.00% <57.14%> (-1.13%) ↓ |
| `src/scvi/model/_scanvi.py` | 91.17% <86.66%> (-1.85%) ↓ |
| `src/scvi/model/base/_rnamixin.py` | 94.17% <93.33%> (-0.36%) ↓ |
| `src/scvi/model/base/_vaemixin.py` | 89.13% <77.77%> (+1.17%) ↑ |
| `src/scvi/model/base/_archesmixin.py` | 78.20% <82.22%> (+1.31%) ↑ |

... and 3 more

codecov[bot] · Aug 11 '24 11:08

Check out this pull request on ReviewNB to see visual diffs & provide feedback on Jupyter Notebooks.

Hi Ori, I gave this branch a try. First, I needed to install psutil. Second, when I ran the tutorial, I found an error:

Code:

```python
model = scvi.model.SCVI(adata=None, registry=datamodule.registry, datamodule=datamodule)
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/allen/programs/celltypes/workgroups/rnaseqanalysis/Mariano/Anaconda3/envs/scvi-env-largedata/lib/python3.12/site-packages/scvi/model/_scvi.py", line 184, in __init__
    self.module = self._module_cls(
                  ^^^^^^^^^^^^^^^^^
TypeError: VAE.__init__() got an unexpected keyword argument 'datamodule'
```

marianogabitto · Apr 16 '25 17:04

I fixed it.

In the tutorial file `scvi-tools/docs/user_guide/use_case/custom_dataloaders.md`, replace the line:

```python
model = scvi.model.SCVI(adata=None, registry=datamodule.registry, datamodule=datamodule)
```

with:

```python
model = scvi.model.SCVI(adata=None, registry=datamodule.registry)
```

marianogabitto · Apr 16 '25 17:04

@marianogabitto thanks. I fixed those things. Also see the `test_custom_dataloder.py` file and the corresponding tutorial to learn the full capabilities we currently have for custom dataloaders (https://github.com/scverse/scvi-tutorials/pull/425).

As you also saw, my concern right now is the slowness of data loading between batches. Let's focus on that. Did you find, or do you have, any suggestions on how we could make it faster?

Feel free to fork the branch and add your commits.

ori-kron-wis · Apr 21 '25 06:04

Hi Ori, I am adding the point that I brought up on Discourse: I compared the TileDB and the regular AnnData loaders head to head, and TileDB is 50% slower. I tested the TileDB dataloader in regular and DDP mode, and what is causing the delay is slow access to the data. The GPU peaks and processes super fast, but between batches there is a long waiting time.

Let me tell you how I did this comparison. On the one hand, I saved the TileDB experiment as an AnnData and ran scvi regularly. On the other hand, I grabbed the dataloader code from scvi (splitting, AnnDataLoader, AnnDataset) and created an scvi-external AnnData loader (just to be sure there was no difference between running the AnnData in regular scvi versus passing it as a new AnnData loader).

Two questions:

1. If you can wait until Wednesday, I am going to talk to people at my high-performance computing cluster and to representatives from TileDB, and try to debug whether there is any way to accelerate the data loader with multiple workers or other options. A data loader that accesses disk is expected to be slower; I want to see how much slower.

2. What dataloader is used in regular scvi? Should I also use the following? Wouldn't the train dataloader only hold an amount of data equivalent to `train_size`?

```python
inference_dataloader = (
    inference_datamodule.on_before_batch_transfer(batch, None)
    for batch in inference_datamodule.train_dataloader()
)
```

M

marianogabitto · Apr 21 '25 06:04

Just checking: is it still faster if you load the AnnData in disk-backed mode? That would indeed be surprising, whereas the other overhead could come from loading from disk. Any chance to use fast SSD storage for the data, as is usually recommended for from-disk loading? The data could also be scattered across SSDs, and TileDB might have suggestions on how to optimize this; imperfect randomness is not really an issue if they first randomize the order on disk (rather than keeping a single experiment on one SSD).
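For reference, disk-backed loading is a one-line change on the AnnData side; a minimal sketch, with a hypothetical file path:

```python
import anndata as ad
import scvi

# Disk-backed mode: X stays in the HDF5 file and is read lazily,
# instead of being materialized fully in memory.
adata = ad.read_h5ad("experiment.h5ad", backed="r")  # hypothetical path

# Setup then proceeds the same way as with an in-memory AnnData.
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
```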

canergen · Apr 21 '25 06:04

Speeds are: anndata in memory > anndata backed >~ TileDB.

Yes, I am testing two things: I am talking with my support team about a partition with fast disk access, and touching base with TileDB about how to optimize this.

marianogabitto · Apr 21 '25 06:04

@marianogabitto I will try to do some testing on my end. Actually, the whole `on_before_batch_transfer` part is something I inherited from the census implementation long ago, but I never had the time to check it fully (see https://github.com/chanzuckerberg/cellxgene-census/blob/756708e9aa18791b7bae3712e9dd66d2b6ce9d75/api/python/notebooks/experimental/pytorch_scvi.ipynb). It makes sense to me that it might be the bottleneck. We are not in a rush to release this; we can wait for the input from your support team.

ori-kron-wis · Apr 21 '25 06:04

But you need `on_before_batch_transfer` to move the batch to PyTorch, no? Isn't the point of that callback to take the X, batch, and labels data and move them into tensors? What happens if they are already tensors?
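Something like the following is what I have in mind for that callback; a minimal sketch, where the dict-of-arrays structure and key handling are assumptions rather than scvi-tools' actual implementation:

```python
import numpy as np
import torch

def on_before_batch_transfer(batch: dict, dataloader_idx: int) -> dict:
    """Convert a raw batch (dict of arrays) into torch tensors.

    Illustrative sketch only; the batch structure is an assumption,
    not the actual scvi-tools code.
    """
    out = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            # Already a tensor: pass through, effectively a cheap no-op.
            out[key] = value
        elif isinstance(value, np.ndarray):
            # Zero-copy wrap of the numpy buffer where the dtype allows it.
            out[key] = torch.from_numpy(value)
        else:
            out[key] = torch.as_tensor(value)
    return out
```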

marianogabitto · Apr 21 '25 07:04

One more thing... how about making a tutorial on the anndata_manager and the registry?

marianogabitto · Apr 21 '25 07:04

We will deal with everything, but let's take it step by step. `on_before_batch_transfer` is needed, of course, but let's see how we can improve it.

ori-kron-wis · Apr 21 '25 07:04

Your suggestion worked, and I do see much better training-time performance when using:

```python
datamodule.setup()
model = scvi.model.SCVI(
    adata=None,
    registry=datamodule.registry,
    n_layers=n_layers,
    n_latent=n_latent,
    gene_likelihood="nb",
    encode_covariates=False,
)

# creating the dataloader for the train set
training_dataloader = (
    datamodule.on_before_batch_transfer(batch, None)
    for batch in datamodule.train_dataloader()
)

import time

start = time.time()
model.train(
    datamodule=training_dataloader,
    # datamodule=datamodule,
    max_epochs=100,
    batch_size=1024,
    # accelerator="gpu",
    # devices=-1,
    # strategy="ddp_find_unused_parameters_true",
)
end = time.time()
print(f"Elapsed time: {end - start:.2f} seconds")
```

See my updated tutorial on the other branch.

ori-kron-wis · Apr 21 '25 12:04

Ori, I got lost... In what branch should I look for the tutorial? What branch should I check to test your commits?

marianogabitto · Apr 23 '25 16:04

These are the tutorials: https://app.reviewnb.com/scverse/scvi-tutorials/pull/425/

The tests are in the current branch.

ori-kron-wis · Apr 23 '25 16:04

Ori, this is not working for me. When I invoke this in the notebook:

```python
training_dataloader = (
    datamodule.on_before_batch_transfer(batch, None)
    for batch in datamodule.train_dataloader()
)
```

I get `switching torch multiprocessing start method from "fork" to "spawn"` and then it errors out...
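A generic PyTorch workaround I could try, not verified against this branch, is to pin the start method once before any DataLoader workers exist (or simply to debug with `num_workers=0`):

```python
import torch.multiprocessing as mp

# Pin the start method up front so torch does not have to switch
# from "fork" to "spawn" after workers have already been created.
if mp.get_start_method(allow_none=True) != "spawn":
    mp.set_start_method("spawn", force=True)
```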

marianogabitto · Apr 25 '25 06:04

Ori, all the examples I list below were run with the `.on_before_batch_transfer()` code removed, the way I posted before.

1. When `num_workers=0`, I can train at the low speeds listed below. When I set the number of workers, e.g. `num_workers=4`, 12, or 24, the trainer takes forever to initialize and is then even slower.

2. Can you monitor your GPU usage with nvitop or nvtop? Here are my head-to-head comparisons (see the timing sketch after this list):

   • TileDB from cell census. I believe this reads from S3, so it never actually copies data to disk. It takes 120 sec/it to train. GPU activity is near zero almost all the time, except for moments when it peaks to 100%.

   • TileDB from an AnnData created from the query. This reads from a local disk directory. It takes 11 sec/it to train. GPU activity is near zero almost all the time, except for moments when it peaks to 100%.

   • The regular way of loading an AnnData into memory. It takes 1.2 sec/it to train. GPU activity is at 40% all the time.

   These led me to believe that we are not loading data into GPU memory fast enough.

3. I forgot to tell you, but the TileDB representative sent me this as a reference. It differs from the way we run because they launch the processes themselves: https://github.com/single-cell-data/TileDB-SOMA-ML/tree/rw/cli/src/tiledbsoma_ml/cli#example-invocation
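One way to confirm that the loader, rather than the model, is the bottleneck is to time iteration over the dataloader alone; a generic sketch, assuming the `datamodule` from the tutorial code above:

```python
import time

# Time raw dataloader throughput with no model in the loop; if this alone
# is slow, the bottleneck is data access rather than the GPU or the model.
loader = datamodule.train_dataloader()
start = time.time()
n_batches = 0
for batch in loader:
    n_batches += 1
    if n_batches == 50:  # a fixed number of batches is enough for a rate
        break
elapsed = time.time() - start
print(f"{n_batches} batches in {elapsed:.1f} s -> {elapsed / n_batches:.2f} s/batch")
```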

marianogabitto · Apr 25 '25 06:04

Hi @marianogabitto, thanks.

1. I made several changes, and I moved `on_before_batch_transfer` into the class, so it is no longer part of the analysis code. If you pulled the branch and reinstalled, you will therefore get errors when running the same code as before. I have updated the tutorials (see there); sorry for this. But I'm not following your code; can you share exactly what you are running, so we can compare?

2. Regarding the GPU behavior, I see the same thing. The data you use matters (the speed enhancement shows up on larger data, not smaller). I don't think the GPU is underutilized; it is just that data loading is much slower in TileDB, as you said, around 100x slower in that sense. So while with AnnData the data loading takes about 1 s and we see almost continuous use of the GPU, with TileDB over S3 there is a ~100 s gap between the same GPU bursts, so we mistakenly see it as underused.

3. `num_workers` enables multiprocess loading and is a parameter of the torch DataLoader (see the sketch after this list). We know its effect also depends on data size, and that we do not always get what we expect from it; specifically, there is overhead in initializing and closing the workers. How do you use it? I will try to check its speed in the custom dataloader context as well; in any case, we should run with whichever number benefits us most, it is not a magic setting that helps every time.

4. I think the common practice is 1 GPU running in a notebook. We need to make sure this works first; other scenarios will follow. Having said that, scripts run with DDP can behave very differently from notebooks, and we need to test all the possibilities. We might find that we need to run it as a script, like the reference you gave. Will check.

5. I added SCANVI to the tutorials as well; some issues still exist in the prediction part for TileDB.
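For reference, these are the standard torch DataLoader knobs in play here; a generic sketch rather than this PR's code, with `dataset` assumed to already exist:

```python
from torch.utils.data import DataLoader

# Worker processes overlap data loading with GPU compute, but carry
# startup/teardown overhead, so the best num_workers depends on data
# size and storage speed.
loader = DataLoader(
    dataset,                  # any torch Dataset; assumed to already exist
    batch_size=1024,
    num_workers=4,            # worker processes; 0 loads in the main process
    persistent_workers=True,  # keep workers alive across epochs
    prefetch_factor=2,        # batches each worker preloads ahead of time
    pin_memory=True,          # page-locked buffers for faster GPU transfer
)
```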

ori-kron-wis · Apr 29 '25 15:04

Ori, I am testing the updates in 12 hours. Sorry for the delay. One more thing in the meantime: it would be great to expose the scvi AnnData DataLoader as an example of what is going on internally. The code below does not work because the BatchDistributedSampler is not outputting the samples with the correct dimensions (in DDP), but if you help me solve it, that would be great.

Code:

```python
from scvi.dataloaders import DataSplitter

scvi.model.SCVI.setup_anndata(
    adata, batch_key="batch", categorical_covariate_keys=["cell_type", "donor"]
)
ad_manager = scvi.model.SCVI._get_most_recent_anndata_manager(adata, required=True)

model = scvi.model.SCVI(
    registry=ad_manager._registry,
    gene_likelihood="nb",
    encode_covariates=False,
)

ad_manager.adata = adata
dl = DataSplitter(
    ad_manager,
    train_size=0.9,
    pin_memory=True,
    num_workers=2,
    persistent_workers=True,
    # prefetch_factor=2,
)
dl.setup()

model.train(
    datamodule=datamodule,
    max_epochs=10,
    batch_size=128,
    train_size=0.9,
    early_stopping=False,
    accelerator="gpu",
    devices=-1,
    strategy="ddp_find_unused_parameters_true",
)
```

marianogabitto · Apr 30 '25 22:04

Hi @marianogabitto, your code above should work in multi-GPU settings; just add `distributed_sampler=True` to the DataSplitter call.
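Adapting your snippet from above, that would look like the following; a sketch of the suggestion rather than a tested call:

```python
# Mariano's DataSplitter call from above, with the suggested flag added
# so each DDP process receives correctly shaped batches.
dl = DataSplitter(
    ad_manager,
    train_size=0.9,
    pin_memory=True,
    num_workers=2,
    persistent_workers=True,
    distributed_sampler=True,  # the flag suggested above for multi-GPU/DDP
)
dl.setup()
```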

Besides that, I made several other updates for this PR. The census/lamin custom dataloaders should now be working for SCVI/SCANVI/scArches, load/save, multi-GPU, and covariate integration.

ori-kron-wis · May 12 '25 12:05