scvi-tools
MULTIVI training fails before first epoch
Hi, I am trying to integrate mostly paired scRNA-seq and scATAC-seq data following your MULTIVI tutorial. Creating the adata_mvi object and setting it up with
scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key='modality')
works fine. However, creating and training the model with:
mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(adata_mvi.var['modality'] == 'Gene Expression').sum(),
    n_regions=(adata_mvi.var['modality'] == 'Peaks').sum(),
)
mvi.train()
results in the error:
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Epoch 1/500: 0%| | 0/500 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-14-b4772eb1a555> in <module>
4 n_regions=(adata_mvi.var['modality']=='Peaks').sum(),
5 )
----> 6 mvi.train()
~/miniconda3/lib/python3.7/site-packages/scvi/model/_multivi.py in train(self, max_epochs, lr, use_gpu, train_size, validation_size, batch_size, weight_decay, eps, early_stopping, save_best, check_val_every_n_epoch, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_mixing, plan_kwargs, **kwargs)
278 **kwargs,
279 )
--> 280 return runner()
281
282 @torch.no_grad()
~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainrunner.py in __call__(self)
70 self.training_plan.n_obs_training = self.data_splitter.n_train
71
---> 72 self.trainer.fit(self.training_plan, self.data_splitter)
73 self._update_history()
74
~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
175 message="`LightningModule.configure_optimizers` returned `None`",
176 )
--> 177 super().fit(*args, **kwargs)
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
458 )
459
--> 460 self._run(model)
461
462 assert self.state.stopped
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
756
757 # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 758 self.dispatch()
759
760 # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
797 self.accelerator.start_predicting(self)
798 else:
--> 799 self.accelerator.start_training(self)
800
801 def run_stage(self):
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
94
95 def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96 self.training_type_plugin.start_training(trainer)
97
98 def start_evaluating(self, trainer: 'pl.Trainer') -> None:
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
142 def start_training(self, trainer: 'pl.Trainer') -> None:
143 # double dispatch to initiate the training loop
--> 144 self._results = trainer.run_stage()
145
146 def start_evaluating(self, trainer: 'pl.Trainer') -> None:
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
807 if self.predicting:
808 return self.run_predict()
--> 809 return self.run_train()
810
811 def _pre_training_routine(self):
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
869 with self.profiler.profile("run_training_epoch"):
870 # run train epoch
--> 871 self.train_loop.run_training_epoch()
872
873 if self.max_steps and self.max_steps <= self.global_step:
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
497 # ------------------------------------
498 with self.trainer.profiler.profile("run_training_batch"):
--> 499 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
500
501 # when returning -1 from train_step, we end epoch early
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
736
737 # optimizer step
--> 738 self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
739 if len(self.trainer.optimizers) > 1:
740 # revert back to previous state
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
440 on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
441 using_native_amp=using_native_amp,
--> 442 using_lbfgs=is_lbfgs,
443 )
444
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
1401
1402 """
-> 1403 optimizer.step(closure=optimizer_closure)
1404
1405 def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
212 profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
213
--> 214 self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
215 self._total_optimizer_step_calls += 1
216
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
132
133 with trainer.profiler.profile(profiler_name):
--> 134 trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
135
136 def step(self, *args, closure: Optional[Callable] = None, **kwargs):
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
327 )
328 if make_optimizer_step:
--> 329 self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
330 self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
331 self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
334 self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
335 ) -> None:
--> 336 self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
337
338 def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
191
192 def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 193 optimizer.step(closure=lambda_closure, **kwargs)
194
195 @property
~/miniconda3/lib/python3.7/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
86 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
87 with torch.autograd.profiler.record_function(profile_name):
---> 88 return func(*args, **kwargs)
89 return wrapper
90
~/miniconda3/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
30
~/miniconda3/lib/python3.7/site-packages/torch/optim/adam.py in step(self, closure)
90 if closure is not None:
91 with torch.enable_grad():
---> 92 loss = closure()
93
94 for group in self.param_groups:
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
731 def train_step_and_backward_closure():
732 result = self.training_step_and_backward(
--> 733 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
734 )
735 return None if result is None else result.loss
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
821 with self.trainer.profiler.profile("training_step_and_backward"):
822 # lightning module hook
--> 823 result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
824 self._curr_step_result = result
825
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
288 model_ref._results = Result()
289 with self.trainer.profiler.profile("training_step"):
--> 290 training_step_output = self.trainer.accelerator.training_step(args)
291 self.trainer.accelerator.post_training_step()
292
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
202
203 with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 204 return self.training_type_plugin.training_step(*args)
205
206 def post_training_step(self) -> None:
~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
153
154 def training_step(self, *args, **kwargs):
--> 155 return self.lightning_module.training_step(*args, **kwargs)
156
157 def post_training_step(self):
~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainingplans.py in training_step(self, batch, batch_idx, optimizer_idx)
362 loss_kwargs = dict(kl_weight=self.kl_weight)
363 inference_outputs, _, scvi_loss = self.forward(
--> 364 batch, loss_kwargs=loss_kwargs
365 )
366 loss = scvi_loss.loss
~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainingplans.py in forward(self, *args, **kwargs)
145 def forward(self, *args, **kwargs):
146 """Passthrough to `model.forward()`."""
--> 147 return self.module(*args, **kwargs)
148
149 def training_step(self, batch, batch_idx, optimizer_idx=0):
~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/miniconda3/lib/python3.7/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
33
34 device = list(set(p.device for p in self.parameters()))
~/miniconda3/lib/python3.7/site-packages/scvi/module/base/_base_module.py in forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
143 tensors, **get_inference_input_kwargs
144 )
--> 145 inference_outputs = self.inference(**inference_inputs, **inference_kwargs)
146 generative_inputs = self._get_generative_input(
147 tensors, inference_outputs, **get_generative_input_kwargs
~/miniconda3/lib/python3.7/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
33
34 device = list(set(p.device for p in self.parameters()))
~/miniconda3/lib/python3.7/site-packages/scvi/module/_multivae.py in inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
293 # Z Encoders
294 qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
--> 295 encoder_input_accessibility, batch_index, *categorical_input
296 )
297 qzm_expr, qzv_expr, z_expr = self.z_encoder_expression(
~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
~/miniconda3/lib/python3.7/site-packages/scvi/nn/_base_components.py in forward(self, x, *cat_list)
292 q_m = self.mean_encoder(q)
293 q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 294 latent = self.z_transformation(reparameterize_gaussian(q_m, q_v))
295 return q_m, q_v, latent
296
~/miniconda3/lib/python3.7/site-packages/scvi/nn/_base_components.py in reparameterize_gaussian(mu, var)
11
12 def reparameterize_gaussian(mu, var):
---> 13 return Normal(mu, var.sqrt()).rsample()
14
15
~/miniconda3/lib/python3.7/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
48 else:
49 batch_shape = self.loc.size()
---> 50 super(Normal, self).__init__(batch_shape, validate_args=validate_args)
51
52 def expand(self, batch_shape, _instance=None):
~/miniconda3/lib/python3.7/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
54 if not valid.all():
55 raise ValueError(
---> 56 f"Expected parameter {param} "
57 f"({type(value).__name__} of shape {tuple(value.shape)}) "
58 f"of distribution {repr(self)} "
ValueError: Expected parameter loc (Tensor of shape (128, 19)) of distribution Normal(loc: torch.Size([128, 19]), scale: torch.Size([128, 19])) to satisfy the constraint Real(), but found invalid values:
tensor([[ 0.1109, 0.5803, 0.3902, ..., -0.4835, -0.8638, 0.0870],
[ 0.7019, 0.4671, 0.4204, ..., 0.0102, -0.8212, 0.0126],
[ 0.1285, 1.1512, -0.1905, ..., -0.4889, -0.0262, 0.0351],
...,
[-0.0526, 0.3792, 0.6689, ..., 0.2085, 0.0496, 0.4914],
[ 0.5148, 0.4604, 0.4606, ..., -0.0603, -0.3616, 0.4082],
[ 0.5613, 0.6148, 0.5383, ..., -0.1725, -1.2356, -0.1374]],
grad_fn=<AddmmBackward0>)
Can you help me with this?
Thanks in advance! Moritz
Versions:
scvi-tools 0.14.5
Hi @MoritzTh, can you provide more details on the shape of your dataset? Also, as a first step, I would try turning down the learning rate.
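For example (a minimal sketch; the exact value to try is a judgment call), a lower learning rate can be passed directly to train():
# assumption: mvi is the MULTIVI model from the snippet above;
# lr is an argument of MULTIVI.train() (visible in the traceback's train() signature)
mvi.train(lr=5e-5)  # lower than the default, to reduce the chance of NaNs in the loss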
Hi, thanks for reaching out!
My final adata_mvi (10X Multiome data) consists of:
- 6242 cells (5752 paired, 281 expression-only, 209 accessibility-only)
- 132384 features (120110 peaks, 12274 genes)
I tried using both raw and normalized data as input. Turning down the learning rate also didn't work...
I would imagine the issue is related to how the data were prepared. First, normalized data should not be used with this model; it expects raw counts. Second, can you print(adata_mvi.var)?
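As a quick sanity check (a minimal sketch, assuming adata_mvi.X is the matrix that was registered with setup_anndata), you can confirm the values look like raw, non-negative integer counts:
import numpy as np
import scipy.sparse as sp

# assumption: adata_mvi.X holds the data passed to MULTIVI
X = adata_mvi.X
vals = X.data if sp.issparse(X) else np.asarray(X)
print("min:", vals.min(), "max:", vals.max())
print("all integer-valued:", bool(np.all(np.mod(vals, 1) == 0)))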
Sure, output is:
gene_symbol             gene_id                 feature_type     chromosome  start_pos_tss  end_pos_tss  modality         n_cells
TNFRSF4                 ENSG00000186827         Gene Expression  1                 1214152      1214153  Gene Expression      156
CDADC1                  ENSG00000102543         Gene Expression  13               49247924     49247925  Gene Expression     1254
ABCC4                   ENSG00000125257         Gene Expression  13               95301318     95301451  Gene Expression     2225
DIS3                    ENSG00000083520         Gene Expression  13               72781899     72782096  Gene Expression     1558
COG6                    ENSG00000133103         Gene Expression  13               39655626     39655727  Gene Expression      822
...                     ...                     ...              ...                   ...          ...  ...                  ...
15:98754122-98754998    15:98754122-98754998    Peaks            15               98754122     98754998  Peaks                183
15:98745417-98746162    15:98745417-98746162    Peaks            15               98745417     98746162  Peaks                180
15:98734498-98735342    15:98734498-98735342    Peaks            15               98734498     98735342  Peaks                295
15:98850647-98851428    15:98850647-98851428    Peaks            15               98850647     98851428  Peaks                773
KI270713.1:36892-37804  KI270713.1:36892-37804  Peaks            KI270713.1          36892        37804  Peaks                182
Can you try it with only the fully paired cells? I'm concerned about the few unpaired cells you have at the moment.
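For reference, a minimal sketch of that subsetting, assuming adata_mvi.obs['modality'] carries the 'paired' / 'expression' / 'accessibility' labels that organize_multiome_anndatas assigns (consistent with the cell counts you reported):
# assumption: adata_mvi.obs['modality'] labels each cell as
# 'paired', 'expression', or 'accessibility'
adata_paired = adata_mvi[adata_mvi.obs['modality'] == 'paired'].copy()

scvi.model.MULTIVI.setup_anndata(adata_paired, batch_key='modality')
mvi = scvi.model.MULTIVI(
    adata_paired,
    n_genes=(adata_paired.var['modality'] == 'Gene Expression').sum(),
    n_regions=(adata_paired.var['modality'] == 'Peaks').sum(),
)
mvi.train()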
That works, thanks! Not sure why, though... it should work with not-fully-paired data as well, right? Even if >90% of cells are paired?
200 unpaired cells may just be too few... @talashuach, what do you think?
Closing this due to inactivity; we can debug further on our Discourse forum.