scvi-tools
GPU memory overflow (leak?) for large datasets
We love scVI and use it all the time for batch-effect correction / dataset harmonization. As we scale to larger and larger datasets, we seem to have hit a wall with scVI's scalability: datasets with more than ~800k cells (using 2000 highly variable genes) never seem to complete training. I get GPU out-of-memory errors late in training, which is strange, since I would not expect GPU memory usage to grow during training itself.
The error only occurs for datasets with roughly 800k cells or more. Here is a sketch of the script I run (not the whole thing) to do batch-effect correction:
import pandas as pd
import numpy as np
import scanpy as sc
import scvi
from scvi.dataset import Dataset10X
from scvi.dataset.anndataset import AnnDatasetFromAnnData
from scvi.models import VAE
from scvi.inference import UnsupervisedTrainer
# load h5ad
adata = sc.read_h5ad(filename=args.file)
# batch info
batch = np.array(adata.obs['scvi_batch'].values, dtype=np.double).flatten()
# remap batch labels to consecutive integers starting at zero
lookup = dict(zip(np.unique(batch), np.arange(0, np.unique(batch).size)))
batch = np.array([lookup[b] for b in batch], dtype=np.double)
# subset adata_hvg to highly variable genes
adata_hvg = adata.copy()
adata_hvg._inplace_subset_var(adata.var['highly_variable'].values)
# create scvi Dataset object
scvi.set_verbosity('WARNING')
scvi_dataset = AnnDatasetFromAnnData(adata_hvg)
# training params
n_epochs = args.max_epochs
lr = 1e-3
use_cuda = True
# train the model and output model likelihood every epoch
vae = VAE(scvi_dataset.nb_genes,
          n_batch=scvi_dataset.n_batches,
          n_latent=args.latent_dim,
          dispersion=dispersion)  # 'dispersion' is set earlier in the full script (omitted from this sketch)
trainer = UnsupervisedTrainer(vae,
                              scvi_dataset,
                              train_size=0.9,
                              use_cuda=use_cuda,
                              frequency=1,
                              early_stopping_kwargs={'early_stopping_metric': 'elbo',
                                                     'save_best_state_metric': 'elbo',
                                                     'patience': 20,
                                                     'threshold': 0.5})
print('Running scVI...')
trainer.train(n_epochs=n_epochs, lr=lr)
print('scVI complete.')
# obtain latents from scvi
posterior = trainer.create_posterior(trainer.model, scvi_dataset, indices=np.arange(len(scvi_dataset))).sequential()
latent, _, _ = posterior.get_latent() # https://github.com/YosefLab/scVI/blob/master/tests/notebooks/scanpy_pbmc3k.ipynb
# store scvi latents as part of the AnnData object
adata.obsm['X_scvi'] = latent
Below is an example of the output from a log file. I run scVI with a maximum of 300 epochs but use early stopping, so it can terminate before that. In this example we get to epoch 142 and the process is then killed. Because this happens during training, before the 'scVI complete.' message from my script is printed, I know the memory error hit GPU memory during training and not CPU memory after training had completed.
Linux error: Killed
Starting workflow.
Working on file.h5ad
scvi_script.py --file file.h5ad --latent_dim 50 --umap --max_epochs 300
Dataset for scvi:
GeneExpressionDataset object with n_cells x nb_genes = 729211 x 2364
gene_attribute_names: 'highly_variable', 'gene_names'
cell_attribute_names: 'local_means', 'batch_indices', 'scvi_batch', 'local_vars', 'labels'
cell_categorical_attribute_names: 'batch_indices', 'labels'
Running scVI...
training: 0%| | 0/300 [00:00<?, ?it/s] training: 0%| | 1/300 [02:18<11:30:33, 138.58s/it] ... [per-epoch progress lines for epochs 2-141 omitted; each epoch takes roughly 135-142 s] ... training: 47%|████▋ | 142/300 [5:25:56<5:58:37, 136.18s/it]
16 Killed python scvi_script.py --file file.h5ad --latent_dim 50 --umap --max_epochs 300
When the number of cells is 500k or 600k, this works just fine!
Versions:
Ubuntu 16.04, scVI 0.6.5, scanpy 1.5.1
Hardware:
Tested on Tesla K80 GPU with 12GB memory, and Tesla P100 GPU with 16GB memory. We see this same failure in both cases.
But the point I want to emphasize is that this is not something for which the answer is "more GPU memory"... I think there is some problem happening during training. Why should memory usage increase over the course of training? Since training happens in mini-batches, why should it matter whether the dataset has 400k or 900k cells?
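To confirm whether GPU usage actually grows over epochs, one option is to log PyTorch's CUDA memory counters periodically during a run. A minimal sketch (the helper name log_gpu_memory is made up here; it relies only on the standard torch.cuda memory statistics and could be called between epochs or between shorter training runs):
import torch

def log_gpu_memory(tag=""):
    # report GPU memory currently allocated by tensors and the peak allocation so far
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024 ** 2
        peak = torch.cuda.max_memory_allocated() / 1024 ** 2
        print(f"{tag} allocated: {allocated:.0f} MiB, peak: {peak:.0f} MiB")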
Thank you for using scVI. This is a bit strange, as GPU memory usage should be constant during training. Is it possible to rerun with the INFO log statements enabled? Maybe it has something to do with the early stopping?
The only other possibility I see is
https://github.com/YosefLab/scVI/blob/dabf88291d15f19d2290026ef9c35504dd20a3ce/scvi/inference/trainer.py#L214
where self.current_loss should actually be set to loss.item()
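For context, that line matters because of a standard PyTorch pitfall: storing a reference to the loss tensor keeps its autograd history alive across iterations, whereas loss.item() stores only a plain Python float. A self-contained toy sketch of the pattern (illustrative only, not the scVI trainer code itself):
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(100, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

losses = []
for step in range(1000):
    x = torch.randn(64, 100, device=device)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # losses.append(loss)        # keeps each loss tensor (and its history) alive
    losses.append(loss.item())   # stores only a float; nothing accumulates on the GPU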
A tangential thought: I'm surprised it takes at least 140 epochs with that many cells; I'd guess the ELBO would converge around ~20 epochs.
We'll continue to look into it.
Also, what version of PyTorch are you using? In the meantime, I'd check whether updating it helps.
Hi @sjfleming, thanks for posting this issue.
I agree with Adam; indeed, it's surprising that training did not converge, and trigger early stopping, well before 50 epochs. We are looking into this; there may be some part of the computational graph that is not cleared and keeps taking up memory, i.e. some side effect we are not aware of. Let's keep in touch.
Thanks for your comments. The PyTorch version I'm running is
pytorch 1.5.0
I typically see that we don't reach convergence until well past 100 epochs. Perhaps my early stopping parameters are more stringent than what you usually use? It might also be worth noting that we often have a rather large number of "batches", anywhere from 10 to 80. The learning curve does flatten out substantially by around epoch 50, but there is still enough of a downward trend that it doesn't trigger early stopping, so I just let it run.
Thanks again for looking into this.
Hi @sjfleming -- on our master branch we just merged our new backend, which uses PyTorch Lightning. You can view the docs on the latest version of the readthedocs page. I'd be surprised if this issue still exists.
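For anyone reading this later, a rough sketch of the equivalent workflow on the new backend (the exact setup call differs between scvi-tools versions, and batch_key='scvi_batch' simply mirrors the column used in the original script):
import scanpy as sc
import scvi

adata = sc.read_h5ad("file.h5ad")
adata = adata[:, adata.var["highly_variable"]].copy()

# register the batch column with scvi-tools (older releases use
# scvi.data.setup_anndata(adata, batch_key="scvi_batch") instead)
scvi.model.SCVI.setup_anndata(adata, batch_key="scvi_batch")

model = scvi.model.SCVI(adata, n_latent=50)
model.train(max_epochs=300, early_stopping=True)
adata.obsm["X_scvi"] = model.get_latent_representation()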
Note to self: it would be great to manually check whether we can reproduce this bug, as scalability is an important feature of the codebase.
Should we create a large random AnnData matrix (1M cells by 20k genes) and try to run scVI overnight?
yes that's a good idea
@khalilouardini would you like to work on this issue? It should be a quick one! Please try to run scVI with either a simulated AnnData or the 1.3M-cell dataset from 10x Genomics (which we should be able to access directly from the codebase) and check that training completes without trouble.
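A minimal sketch of how a large synthetic AnnData could be put together for this test (the sizes, sparsity, and number of batches are arbitrary; only numpy, scipy, and anndata are assumed):
import anndata
import numpy as np
from scipy import sparse

n_cells, n_genes = 1_000_000, 2_000
rng = np.random.default_rng(0)

# sparse count matrix with ~5% non-zero entries drawn from a Poisson
counts = sparse.random(
    n_cells, n_genes, density=0.05, format="csr",
    data_rvs=lambda n: rng.poisson(2.0, n).astype(np.float32) + 1.0,
)
adata = anndata.AnnData(X=counts)
adata.obs["scvi_batch"] = rng.integers(0, 20, size=n_cells)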
Also, @khalilouardini, you should monitor the memory used over time, for both a CPU run and a GPU run. Thanks!
Ok I'm on it!
@watiss can you take a look at this? You can pass this callback to the trainer:
https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.callbacks.gpu_stats_monitor.html
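Something along these lines might work, assuming adata has already been registered via setup_anndata and that the model's train() forwards Lightning Trainer keyword arguments such as callbacks (treat this as a sketch, not a guaranteed API):
import scvi
from pytorch_lightning.callbacks import GPUStatsMonitor

# 'adata' is assumed to be an AnnData already registered with scvi-tools
model = scvi.model.SCVI(adata, n_latent=50)

# log GPU memory and utilization at each training step
gpu_stats = GPUStatsMonitor(memory_utilization=True, gpu_utilization=True)
model.train(max_epochs=300, callbacks=[gpu_stats])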
Any updates on this?
I'm not sure this is still an issue. Are you running into it, @PedroMilanezAlmeida?
For what it's worth, I do suspect the original issue was resolved when scvi was moved to the new pytorch-lightning backend. I haven't personally seen the issue since then, though I admit I haven't tried to do any comprehensive testing.
@adamgayoso
I'm not sure this is still an issue. Are you running into it, @PedroMilanezAlmeida?
I am not currently running into this issue but need to make a decision about which platform to use moving forward for several very large projects.
Got it. FWIW, this should be one of the most memory-efficient platforms for single-cell integration.
I'm going to close this issue as it's not clear that it's still a problem. The issue was created before our lightning integration and therefore any custom training code that could have caused a leak has been removed.