scvi-tools

GPU memory overflow (leak?) for large datasets

sjfleming opened this issue 5 years ago · 12 comments

We love scVI and use it all the time for batch-effect correction / dataset harmonization. As we scale to larger and larger datasets, we seem to have hit a wall with scVI's scalability: with more than ~800k cells (using 2000 highly variable genes), training never completes. I seem to be getting GPU out-of-memory errors late in training. This is strange, since I would not expect GPU memory usage to grow during training itself.

The error only occurs with datasets of roughly 800k cells or more. Here is a sketch of the script I run (not the whole thing) to do batch-effect correction:

import pandas as pd
import numpy as np

import scanpy as sc

import scvi
from scvi.dataset import Dataset10X
from scvi.dataset.anndataset import AnnDatasetFromAnnData
from scvi.models import VAE
from scvi.inference import UnsupervisedTrainer

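# NOTE: `args` (parsed command-line arguments) and `dispersion` are defined
# in the parts of the script not shown in this sketch

# load h5ad
adata = sc.read_h5ad(filename=args.file)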

# batch info
batch = np.array(adata.obs['scvi_batch'].values, dtype=np.double).flatten()

# ensure the batch numbers are in order starting at zero
# (return_inverse maps each value to its index among the sorted unique values)
batch = np.unique(batch, return_inverse=True)[1].astype(np.double)

# subset adata_hvg to highly variable genes (public slicing API instead of
# the private _inplace_subset_var)
adata_hvg = adata[:, adata.var['highly_variable'].values].copy()

# create scvi Dataset object
scvi.set_verbosity('WARNING')
scvi_dataset = AnnDatasetFromAnnData(adata_hvg)

# training params
n_epochs = args.max_epochs
lr = 1e-3
use_cuda = True

# train the model and output model likelihood every epoch
vae = VAE(scvi_dataset.nb_genes,
          n_batch=scvi_dataset.n_batches,
          n_latent=args.latent_dim,
          dispersion=dispersion)
trainer = UnsupervisedTrainer(vae,
                              scvi_dataset,
                              train_size=0.9,
                              use_cuda=use_cuda,
                              frequency=1,
                              early_stopping_kwargs={'early_stopping_metric': 'elbo',
                                                     'save_best_state_metric': 'elbo',
                                                     'patience': 20,
                                                     'threshold': 0.5})

print('Running scVI...')

trainer.train(n_epochs=n_epochs, lr=lr)

print('scVI complete.')

# obtain latents from scvi
posterior = trainer.create_posterior(trainer.model, scvi_dataset, indices=np.arange(len(scvi_dataset))).sequential()
latent, _, _ = posterior.get_latent()  # https://github.com/YosefLab/scVI/blob/master/tests/notebooks/scanpy_pbmc3k.ipynb

# store scvi latents as part of the AnnData object
adata.obsm['X_scvi'] = latent

Below is an example of the output from a log file. I run scVI with a maximum of 300 epochs but use early stopping, so it can terminate earlier. In this example, training reaches epoch 142 and then the process is killed. Because this happens during training (the script never prints the "scVI complete." message), I know that the memory error was in GPU memory, and not in CPU memory after training had completed.

Linux error: Killed

Starting workflow.
Working on file.h5ad
scvi_script.py --file file.h5ad --latent_dim 50 --umap --max_epochs 300
Dataset for scvi:
GeneExpressionDataset object with n_cells x nb_genes = 729211 x 2364
gene_attribute_names: 'highly_variable', 'gene_names'
cell_attribute_names: 'local_means', 'batch_indices', 'scvi_batch', 'local_vars', 'labels'
cell_categorical_attribute_names: 'batch_indices', 'labels'
Running scVI...
training: 0%| | 0/300 [00:00<?, ?it/s]
training: 0%| | 1/300 [02:18<11:30:33, 138.58s/it]
training: 1%| | 2/300 [04:37<11:29:01, 138.73s/it]
... (progress updates for epochs 3 to 140 omitted; every epoch took between about 134 and 142 s/it)
training: 47%|████▋ | 141/300 [5:23:41<6:02:57, 136.97s/it]
training: 47%|████▋ | 142/300 [5:25:56<5:58:37, 136.18s/it]
16 Killed python scvi_script.py --file file.h5ad --latent_dim 50 --umap --max_epochs 300

When the number of cells is 500k or 600k, this works just fine!

Versions:

Ubuntu 16.04, scVI 0.6.5, scanpy 1.5.1

Hardware:

Tested on Tesla K80 GPU with 12GB memory, and Tesla P100 GPU with 16GB memory. We see this same failure in both cases.

But the point I want to emphasize is that this is not something for which the answer is "more GPU memory"... I think something goes wrong during training. Why should memory usage increase over the course of training? When training in mini-batches, why should it matter whether the dataset has 400k or 900k cells?

sjfleming, Jul 21 '20
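The mini-batch argument can be made concrete with a back-of-the-envelope calculation. This sketch assumes dense float32 inputs and a hypothetical mini-batch size of 128, and it only counts the input tensor (not activations, parameters, or optimizer state); the cell and gene counts come from the log above.

```python
def tensor_bytes(n_rows, n_cols, bytes_per_value=4):
    """Size in bytes of a dense matrix (float32 by default)."""
    return n_rows * n_cols * bytes_per_value


N_CELLS = 729_211  # n_cells from the log above
N_GENES = 2_364    # nb_genes from the log above
BATCH_SIZE = 128   # assumption: a typical mini-batch size

full_matrix = tensor_bytes(N_CELLS, N_GENES)  # whole dataset, dense
one_batch = tensor_bytes(BATCH_SIZE, N_GENES)  # a single mini-batch

print(f"whole dataset, dense float32: {full_matrix / 2**30:.2f} GiB")
print(f"one mini-batch, dense float32: {one_batch / 2**20:.2f} MiB")
```

One mini-batch is about 1.2 MB whether the dataset has 400k or 900k cells, so per-step GPU memory should not depend on dataset size; memory that keeps growing across epochs suggests that something retains tensors between steps.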

Thank you for using scVI. This is a bit strange, as GPU memory usage should stay constant during training. Could you run with INFO-level logging enabled (e.g. scvi.set_verbosity('INFO'))? Maybe it has something to do with early stopping?

The only other possibility I see is

https://github.com/YosefLab/scVI/blob/dabf88291d15f19d2290026ef9c35504dd20a3ce/scvi/inference/trainer.py#L214

where self.current_loss should actually be set to loss.item() rather than to the loss tensor itself.

A tangential thought: I'm surprised it takes at least 140 epochs with that many cells. I'd guess that the ELBO converges within about 20 epochs.

We'll continue to look into it.

adamgayoso, Jul 21 '20
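The loss.item() suggestion follows standard PyTorch practice (the PyTorch FAQ warns against accumulating loss tensors across iterations): keep bookkeeping values as plain Python numbers rather than tensors that carry autograd history. A minimal sketch with a hypothetical toy model and random data, not scVI's actual trainer:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(10, 1)  # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

loss_history = []  # plain Python floats for logging / early stopping
for step in range(5):
    x = torch.randn(32, 10)  # hypothetical mini-batch
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Record the value, not the tensor: `loss` carries autograd history
    # (grad_fn), and holding such tensors across iterations (e.g.
    # `total += loss`) keeps that history alive. .item() returns a float.
    loss_history.append(loss.item())
```

The same applies on a GPU; .item() also copies the scalar back to the host, so no device tensor is kept around for logging.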

Also, what version of PyTorch are you using? In the meantime, I'd also check whether updating helps.

adamgayoso, Jul 21 '20

Hi @sjfleming, thanks for posting this issue.

I agree with Adam; I'm also surprised that early stopping did not trigger well before 50 epochs. We are looking into this: some part of the computational graph might not be freed and keep taking up memory, probably a side effect we are not aware of. Let's keep in touch.

romain-lopez, Jul 24 '20

Thanks for your comments. The PyTorch version I'm running is

pytorch 1.5.0

I typically see that we don't reach convergence until past 100 epochs. Perhaps my early stopping parameters are more stringent than what you usually use? It might also be worth noting that we often have a rather large number of "batches", anywhere from 10 to 80. The learning curve flattens out substantially by around epoch 50, but there is still enough of a downward trend that it doesn't trigger early stopping, so I just let it run.

Thanks again for looking into this.

sjfleming, Jul 25 '20
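Why a slow but steady downward trend can keep early stopping from firing is easier to see with a generic patience/threshold rule. This is a hypothetical sketch, not scVI 0.6.5's actual early-stopping code: training stops once `patience` consecutive epochs pass without the monitored loss (lower is better, e.g. negative ELBO) improving on its best value by more than `threshold`.

```python
def early_stop_epoch(history, patience=20, threshold=0.5):
    """Return the epoch index at which training would stop, or None.

    Hypothetical rule: an epoch counts as an improvement only if the loss
    beats the best value seen so far by more than `threshold`; training
    stops after `patience` consecutive epochs without an improvement.
    """
    best = float("inf")
    wait = 0
    for epoch, value in enumerate(history):
        if best - value > threshold:
            best = value  # real improvement: reset the counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


flat = [1000.0] * 300                              # fully converged
slow = [1000.0 - 0.1 * e for e in range(300)]      # still drifting down
print(early_stop_epoch(flat), early_stop_epoch(slow))
```

Under this rule, with patience=20 and threshold=0.5, training only stops once the loss improves by less than 0.5 in total over 20 epochs (roughly 0.025 per epoch), so a curve that still drops by 0.1 per epoch never triggers it.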

Hi @sjfleming, we just merged our new backend, which uses PyTorch Lightning, into our master branch. You can view the docs under the "latest" version of the Read the Docs page. I'd be surprised if this issue still exists.

adamgayoso, Jan 21 '21

Note to self: it would be great to manually check whether we can reproduce this bug, as scalability is an important feature of the codebase.

romain-lopez, Mar 01 '21

Should we create a large random AnnData matrix (1M cells × 20k genes) and try running scVI on it overnight?

romain-lopez, Mar 01 '21
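A minimal sketch of generating such a synthetic count matrix, assuming NumPy and SciPy. The helper name make_random_counts, the density, and the shifted-Poisson count model are hypothetical choices for illustration, not something specified in this thread. The demo uses a small shape; at the proposed 1M × 20k scale, the density is what controls memory, since every stored entry costs a value plus a column index.

```python
import numpy as np
import scipy.sparse as sp


def make_random_counts(n_cells, n_genes, density=0.05, mean_count=2.0, seed=0):
    """Random sparse cell-by-gene count matrix (CSR, float32).

    Hypothetical helper: each stored entry is a Poisson(mean_count) draw
    shifted by +1, so every nonzero is a positive integer-valued count.
    """
    rng = np.random.default_rng(seed)
    return sp.random(
        n_cells,
        n_genes,
        density=density,
        format="csr",
        dtype=np.float32,
        random_state=rng,
        data_rvs=lambda k: rng.poisson(mean_count, size=k) + 1,
    )


# Small demo; the proposal above is roughly 1M cells x 20k genes.
X = make_random_counts(n_cells=1_000, n_genes=200)
print(X.shape, X.nnz)
```

The resulting matrix can be wrapped with anndata.AnnData(X), which accepts sparse input, before being handed to scVI.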

Yes, that's a good idea.

adamgayoso, Mar 03 '21

@khalilouardini, would you like to work on this issue? It should be a quick one! Please try running scVI on either a simulated AnnData or the 1.3M-cell dataset from 10x Genomics (which we should be able to access directly from the codebase), and check that training completes without trouble.

romain-lopez, Mar 03 '21

Well, @khalilouardini, you should monitor memory usage over time, for both a CPU run and a GPU run. Thanks!

adamgayoso, Mar 03 '21
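A minimal sketch of per-epoch memory sampling for such a run, assuming a Linux host (it reads /proc/self/statm) and, optionally, PyTorch with CUDA. The helper names and the leak heuristic (10-sample window, 50 MiB growth) are hypothetical choices for illustration, not part of scvi-tools. Memory that stays flat across epochs is expected; steady epoch-over-epoch growth is the signature of a leak.

```python
import os


def current_rss_mib():
    """Current resident set size of this process in MiB (Linux only)."""
    with open("/proc/self/statm") as f:
        resident_pages = int(f.read().split()[1])
    return resident_pages * os.sysconf("SC_PAGE_SIZE") / 2**20


def gpu_allocated_mib():
    """CUDA memory currently held by tensors, in MiB; None without a GPU."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.memory_allocated() / 2**20


def looks_like_leak(samples, window=10, min_growth_mib=50.0):
    """Hypothetical heuristic: memory never dropped over the last `window`
    samples and grew by at least `min_growth_mib` over that window."""
    recent = samples[-window:]
    if len(recent) < window:
        return False
    never_drops = all(b >= a for a, b in zip(recent, recent[1:]))
    return never_drops and recent[-1] - recent[0] >= min_growth_mib


cpu_samples, gpu_samples = [], []
for epoch in range(3):
    # ... one training epoch would run here ...
    cpu_samples.append(current_rss_mib())
    gpu = gpu_allocated_mib()
    if gpu is not None:
        gpu_samples.append(gpu)
```

After training, looks_like_leak(cpu_samples) and looks_like_leak(gpu_samples) give a quick yes/no; plotting the samples against the epoch index is more informative for a real investigation.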

OK, I'm on it!

khalilouardini, Mar 04 '21

@watiss, can you take a look at this? You can pass this callback to the trainer:

https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.gpu_stats_monitor.html

adamgayoso, Nov 15 '21

Any updates on this?

PedroMilanezAlmeida, Feb 22 '23

I'm not sure this is still an issue, are you running into this issue @PedroMilanezAlmeida ?

adamgayoso, Feb 22 '23

For what it's worth, I do suspect the original issue was resolved when scvi was moved to the new pytorch-lightning backend. I haven't personally seen the issue since then, though I admit I haven't tried to do any comprehensive testing.

sjfleming, Feb 22 '23

@adamgayoso

"I'm not sure this is still an issue, are you running into this issue @PedroMilanezAlmeida ?"

I am not currently running into this issue, but I need to decide which platform to use going forward for several very large projects.

PedroMilanezAlmeida, Feb 22 '23

Got it. FWIW, this should be one of the most memory-efficient platforms for single-cell integration.

adamgayoso, Feb 22 '23

I'm going to close this issue, as it's not clear that it's still a problem. The issue was opened before our Lightning integration, and any custom training code that could have caused a leak has since been removed.

adamgayoso, Feb 28 '23