
MULTIVI training fails before first epoch

Open · MoritzTh opened this issue 3 years ago · 7 comments

Hi, I am trying to integrate mostly paired scRNA and scATAC data following your MULTIVI tutorial. Creating the combined adata_mvi object and setting it up with

scvi.model.MULTIVI.setup_anndata(adata_mvi, batch_key='modality')

works fine. However, creating and training the model with:

mvi = scvi.model.MULTIVI(
    adata_mvi,
    n_genes=(adata_mvi.var['modality']=='Gene Expression').sum(),
    n_regions=(adata_mvi.var['modality']=='Peaks').sum(),
)
mvi.train()

results in the error:

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Epoch 1/500:   0%|          | 0/500 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-14-b4772eb1a555> in <module>
      4     n_regions=(adata_mvi.var['modality']=='Peaks').sum(),
      5 )
----> 6 mvi.train()

~/miniconda3/lib/python3.7/site-packages/scvi/model/_multivi.py in train(self, max_epochs, lr, use_gpu, train_size, validation_size, batch_size, weight_decay, eps, early_stopping, save_best, check_val_every_n_epoch, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_mixing, plan_kwargs, **kwargs)
    278             **kwargs,
    279         )
--> 280         return runner()
    281 
    282     @torch.no_grad()

~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainrunner.py in __call__(self)
     70             self.training_plan.n_obs_training = self.data_splitter.n_train
     71 
---> 72         self.trainer.fit(self.training_plan, self.data_splitter)
     73         self._update_history()
     74 

~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainer.py in fit(self, *args, **kwargs)
    175                     message="`LightningModule.configure_optimizers` returned `None`",
    176                 )
--> 177             super().fit(*args, **kwargs)

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloader, val_dataloaders, datamodule)
    458         )
    459 
--> 460         self._run(model)
    461 
    462         assert self.state.stopped

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in _run(self, model)
    756 
    757         # dispatch `start_training` or `start_evaluating` or `start_predicting`
--> 758         self.dispatch()
    759 
    760         # plugin will finalized fitting (e.g. ddp_spawn will load trained model)

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in dispatch(self)
    797             self.accelerator.start_predicting(self)
    798         else:
--> 799             self.accelerator.start_training(self)
    800 
    801     def run_stage(self):

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in start_training(self, trainer)
     94 
     95     def start_training(self, trainer: 'pl.Trainer') -> None:
---> 96         self.training_type_plugin.start_training(trainer)
     97 
     98     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in start_training(self, trainer)
    142     def start_training(self, trainer: 'pl.Trainer') -> None:
    143         # double dispatch to initiate the training loop
--> 144         self._results = trainer.run_stage()
    145 
    146     def start_evaluating(self, trainer: 'pl.Trainer') -> None:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_stage(self)
    807         if self.predicting:
    808             return self.run_predict()
--> 809         return self.run_train()
    810 
    811     def _pre_training_routine(self):

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py in run_train(self)
    869                 with self.profiler.profile("run_training_epoch"):
    870                     # run train epoch
--> 871                     self.train_loop.run_training_epoch()
    872 
    873                 if self.max_steps and self.max_steps <= self.global_step:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_epoch(self)
    497             # ------------------------------------
    498             with self.trainer.profiler.profile("run_training_batch"):
--> 499                 batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
    500 
    501             # when returning -1 from train_step, we end epoch early

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in run_training_batch(self, batch, batch_idx, dataloader_idx)
    736 
    737                         # optimizer step
--> 738                         self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    739                         if len(self.trainer.optimizers) > 1:
    740                             # revert back to previous state

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in optimizer_step(self, optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
    440             on_tpu=self.trainer._device_type == DeviceType.TPU and _TPU_AVAILABLE,
    441             using_native_amp=using_native_amp,
--> 442             using_lbfgs=is_lbfgs,
    443         )
    444 

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/core/lightning.py in optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, on_tpu, using_native_amp, using_lbfgs)
   1401 
   1402         """
-> 1403         optimizer.step(closure=optimizer_closure)
   1404 
   1405     def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int):

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py in step(self, closure, *args, **kwargs)
    212             profiler_name = f"optimizer_step_and_closure_{self._optimizer_idx}"
    213 
--> 214         self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
    215         self._total_optimizer_step_calls += 1
    216 

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/core/optimizer.py in __optimizer_step(self, closure, profiler_name, **kwargs)
    132 
    133         with trainer.profiler.profile(profiler_name):
--> 134             trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
    135 
    136     def step(self, *args, closure: Optional[Callable] = None, **kwargs):

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in optimizer_step(self, optimizer, opt_idx, lambda_closure, **kwargs)
    327         )
    328         if make_optimizer_step:
--> 329             self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
    330         self.precision_plugin.post_optimizer_step(optimizer, opt_idx)
    331         self.training_type_plugin.post_optimizer_step(optimizer, opt_idx, **kwargs)

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in run_optimizer_step(self, optimizer, optimizer_idx, lambda_closure, **kwargs)
    334         self, optimizer: Optimizer, optimizer_idx: int, lambda_closure: Callable, **kwargs: Any
    335     ) -> None:
--> 336         self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
    337 
    338     def optimizer_zero_grad(self, current_epoch: int, batch_idx: int, optimizer: Optimizer, opt_idx: int) -> None:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in optimizer_step(self, optimizer, lambda_closure, **kwargs)
    191 
    192     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
--> 193         optimizer.step(closure=lambda_closure, **kwargs)
    194 
    195     @property

~/miniconda3/lib/python3.7/site-packages/torch/optim/optimizer.py in wrapper(*args, **kwargs)
     86                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
     87                 with torch.autograd.profiler.record_function(profile_name):
---> 88                     return func(*args, **kwargs)
     89             return wrapper
     90 

~/miniconda3/lib/python3.7/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

~/miniconda3/lib/python3.7/site-packages/torch/optim/adam.py in step(self, closure)
     90         if closure is not None:
     91             with torch.enable_grad():
---> 92                 loss = closure()
     93 
     94         for group in self.param_groups:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in train_step_and_backward_closure()
    731                         def train_step_and_backward_closure():
    732                             result = self.training_step_and_backward(
--> 733                                 split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
    734                             )
    735                             return None if result is None else result.loss

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in training_step_and_backward(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
    821         with self.trainer.profiler.profile("training_step_and_backward"):
    822             # lightning module hook
--> 823             result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
    824             self._curr_step_result = result
    825 

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py in training_step(self, split_batch, batch_idx, opt_idx, hiddens)
    288             model_ref._results = Result()
    289             with self.trainer.profiler.profile("training_step"):
--> 290                 training_step_output = self.trainer.accelerator.training_step(args)
    291                 self.trainer.accelerator.post_training_step()
    292 

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py in training_step(self, args)
    202 
    203         with self.precision_plugin.train_step_context(), self.training_type_plugin.train_step_context():
--> 204             return self.training_type_plugin.training_step(*args)
    205 
    206     def post_training_step(self) -> None:

~/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py in training_step(self, *args, **kwargs)
    153 
    154     def training_step(self, *args, **kwargs):
--> 155         return self.lightning_module.training_step(*args, **kwargs)
    156 
    157     def post_training_step(self):

~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainingplans.py in training_step(self, batch, batch_idx, optimizer_idx)
    362             loss_kwargs = dict(kl_weight=self.kl_weight)
    363             inference_outputs, _, scvi_loss = self.forward(
--> 364                 batch, loss_kwargs=loss_kwargs
    365             )
    366             loss = scvi_loss.loss

~/miniconda3/lib/python3.7/site-packages/scvi/train/_trainingplans.py in forward(self, *args, **kwargs)
    145     def forward(self, *args, **kwargs):
    146         """Passthrough to `model.forward()`."""
--> 147         return self.module(*args, **kwargs)
    148 
    149     def training_step(self, batch, batch_idx, optimizer_idx=0):

~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/lib/python3.7/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
     30         # decorator only necessary after training
     31         if self.training:
---> 32             return fn(self, *args, **kwargs)
     33 
     34         device = list(set(p.device for p in self.parameters()))

~/miniconda3/lib/python3.7/site-packages/scvi/module/base/_base_module.py in forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    143             tensors, **get_inference_input_kwargs
    144         )
--> 145         inference_outputs = self.inference(**inference_inputs, **inference_kwargs)
    146         generative_inputs = self._get_generative_input(
    147             tensors, inference_outputs, **get_generative_input_kwargs

~/miniconda3/lib/python3.7/site-packages/scvi/module/base/_decorators.py in auto_transfer_args(self, *args, **kwargs)
     30         # decorator only necessary after training
     31         if self.training:
---> 32             return fn(self, *args, **kwargs)
     33 
     34         device = list(set(p.device for p in self.parameters()))

~/miniconda3/lib/python3.7/site-packages/scvi/module/_multivae.py in inference(self, x, batch_index, cont_covs, cat_covs, n_samples)
    293         # Z Encoders
    294         qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
--> 295             encoder_input_accessibility, batch_index, *categorical_input
    296         )
    297         qzm_expr, qzv_expr, z_expr = self.z_encoder_expression(

~/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

~/miniconda3/lib/python3.7/site-packages/scvi/nn/_base_components.py in forward(self, x, *cat_list)
    292         q_m = self.mean_encoder(q)
    293         q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 294         latent = self.z_transformation(reparameterize_gaussian(q_m, q_v))
    295         return q_m, q_v, latent
    296 

~/miniconda3/lib/python3.7/site-packages/scvi/nn/_base_components.py in reparameterize_gaussian(mu, var)
     11 
     12 def reparameterize_gaussian(mu, var):
---> 13     return Normal(mu, var.sqrt()).rsample()
     14 
     15 

~/miniconda3/lib/python3.7/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
     48         else:
     49             batch_shape = self.loc.size()
---> 50         super(Normal, self).__init__(batch_shape, validate_args=validate_args)
     51 
     52     def expand(self, batch_shape, _instance=None):

~/miniconda3/lib/python3.7/site-packages/torch/distributions/distribution.py in __init__(self, batch_shape, event_shape, validate_args)
     54                 if not valid.all():
     55                     raise ValueError(
---> 56                         f"Expected parameter {param} "
     57                         f"({type(value).__name__} of shape {tuple(value.shape)}) "
     58                         f"of distribution {repr(self)} "

ValueError: Expected parameter loc (Tensor of shape (128, 19)) of distribution Normal(loc: torch.Size([128, 19]), scale: torch.Size([128, 19])) to satisfy the constraint Real(), but found invalid values:
tensor([[ 0.1109,  0.5803,  0.3902,  ..., -0.4835, -0.8638,  0.0870],
        [ 0.7019,  0.4671,  0.4204,  ...,  0.0102, -0.8212,  0.0126],
        [ 0.1285,  1.1512, -0.1905,  ..., -0.4889, -0.0262,  0.0351],
        ...,
        [-0.0526,  0.3792,  0.6689,  ...,  0.2085,  0.0496,  0.4914],
        [ 0.5148,  0.4604,  0.4606,  ..., -0.0603, -0.3616,  0.4082],
        [ 0.5613,  0.6148,  0.5383,  ..., -0.1725, -1.2356, -0.1374]],
       grad_fn=<AddmmBackward0>)

Can you help me with this?

Thanks in advance! Moritz

Versions:

scvi-tools 0.14.5

MoritzTh avatar Dec 14 '21 08:12 MoritzTh

Hi @MoritzTh, can you provide more details on the shape of your dataset? Also, as a first step I would try turning down the learning rate.
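
For example, a minimal sketch (the exact value is only an illustration; lr is a parameter of MULTIVI.train, as visible in the traceback above):

mvi.train(lr=1e-4)  # pass a smaller learning rate than whatever you used before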

adamgayoso avatar Dec 14 '21 22:12 adamgayoso

Hi, thanks for reaching out!

My final adata_mvi (10X Multiome data) consists of

  • 6242 cells (5752 paired, 281 expression, 209 accessibility)
  • 132384 features (120110 peaks, 12274 gene expression)

I tried using both raw and normalized data as input. Turning down the learning rate also didn't work...

MoritzTh avatar Dec 15 '21 14:12 MoritzTh

I would imagine the issue is related to preparation of the data. First, normalized data should not be used with this model. Second, can you print(adata_mvi.var)?
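
As a quick sanity check (a sketch, assuming adata_mvi.X holds the combined count matrix), you can verify that the input is raw integer counts rather than normalized values:

import numpy as np
from scipy import sparse

# MULTIVI expects raw counts: check that X is non-negative and integer-valued
X = adata_mvi.X
vals = X.data if sparse.issparse(X) else np.asarray(X)
print("min:", vals.min(), "all integers:", bool(np.all(np.mod(vals, 1) == 0)))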

adamgayoso avatar Dec 15 '21 16:12 adamgayoso

Sure, output is:

                                       gene_id     feature_type  chromosome  \
gene_symbol                                                                   
TNFRSF4                        ENSG00000186827  Gene Expression           1   
CDADC1                         ENSG00000102543  Gene Expression          13   
ABCC4                          ENSG00000125257  Gene Expression          13   
DIS3                           ENSG00000083520  Gene Expression          13   
COG6                           ENSG00000133103  Gene Expression          13   
...                                        ...              ...         ...   
15:98754122-98754998      15:98754122-98754998            Peaks          15   
15:98745417-98746162      15:98745417-98746162            Peaks          15   
15:98734498-98735342      15:98734498-98735342            Peaks          15   
15:98850647-98851428      15:98850647-98851428            Peaks          15   
KI270713.1:36892-37804  KI270713.1:36892-37804            Peaks  KI270713.1   

                        start_pos_tss  end_pos_tss         modality  n_cells  
gene_symbol                                                                   
TNFRSF4                       1214152      1214153  Gene Expression      156  
CDADC1                       49247924     49247925  Gene Expression     1254  
ABCC4                        95301318     95301451  Gene Expression     2225  
DIS3                         72781899     72782096  Gene Expression     1558  
COG6                         39655626     39655727  Gene Expression      822  
...                               ...          ...              ...      ...  
15:98754122-98754998         98754122     98754998            Peaks      183  
15:98745417-98746162         98745417     98746162            Peaks      180  
15:98734498-98735342         98734498     98735342            Peaks      295  
15:98850647-98851428         98850647     98851428            Peaks      773  
KI270713.1:36892-37804          36892        37804            Peaks      182  

MoritzTh avatar Dec 15 '21 17:12 MoritzTh

Can you try it in the case where you only have fully paired data? I'm concerned about the few unpaired cells you have at the moment.
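
For illustration, a minimal sketch of subsetting to the paired cells (assuming the per-cell modality is stored in adata_mvi.obs['modality'] with paired cells labelled 'paired', as in the tutorial; adjust the label to match your object):

# keep only the fully paired cells and repeat setup + training on the subset
adata_paired = adata_mvi[adata_mvi.obs["modality"] == "paired"].copy()
scvi.model.MULTIVI.setup_anndata(adata_paired, batch_key="modality")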

adamgayoso avatar Dec 15 '21 17:12 adamgayoso

That works, thanks! Not sure why, though... it should work with partially paired data as well, right? Even if >90% of cells are paired?

MoritzTh avatar Dec 16 '21 09:12 MoritzTh

200 unpaired may just be too few... @talashuach, what do you think?

adamgayoso avatar Dec 16 '21 16:12 adamgayoso

Closing this due to inactivity; we can debug further on our Discourse forum.

adamgayoso avatar Mar 06 '23 19:03 adamgayoso