scvi-tools (GitHub issue)

MultiVI: training fails with ValueError — tensor contains invalid (NaN) values

Open bhhlee opened this issue 1 year ago • 3 comments

I am trying to use MultiVI on a paired 10x Multiome (RNA + ATAC) dataset. I took the MuData object with raw counts and converted it into an AnnData object for MultiVI. However, training fails during the very first epoch, before any step completes. I've seen several similar issues here, but none of the suggested solutions work for me.

I have a matrix of 161,598 cells × 195,122 features (genes + peaks) containing only paired data.

adata_paired = ad.concat([rna.copy().T, atac.copy().T]).T
adata_paired.obs = adata_paired.obs.join(rna.obs['sample'])
adata_paired.obs["modality"] = "paired"
CPU times: user 49.5 s, sys: 8.99 s, total: 58.5 s
Wall time: 58.5 s
AnnData object with n_obs × n_vars = 161598 × 195122
    obs: 'sample', 'modality'
    var: 'gene_ids', 'feature_types', 'genome', 'interval'
mvi = scvi.model.MULTIVI(


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Epoch 1/10:   0%|                                        | 0/10 [00:00<?, ?it/s]
ValueError                                Traceback (most recent call last)
File <timed eval>:1

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/model/, in MULTIVI.train(self, max_epochs, lr, use_gpu, accelerator, devices, train_size, validation_size, shuffle_set_split, batch_size, weight_decay, eps, early_stopping, save_best, check_val_every_n_epoch, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_mixing, plan_kwargs, **kwargs)
    338 training_plan = self._training_plan_cls(self.module, **plan_kwargs)
    339 runner = self._train_runner_cls(
    340     self,
    341     training_plan=training_plan,
    351     **kwargs,
    352 )
--> 353 return runner()

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/train/, in TrainRunner.__call__(self)
     96 if hasattr(self.data_splitter, "n_val"):
     97     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 99, self.data_splitter)
    100 self._update_history()
    102 # data splitter only gets these attrs after fit

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/train/, in, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/, in, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    527 model = _maybe_unwrap_optimized(model)
    528 self.strategy._lightning_module = model
--> 529 call._call_and_handle_interrupt(
    530     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    531 )

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     40     if trainer.strategy.launcher is not None:
     41         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 42     return trainer_fn(*args, **kwargs)
     44 except _TunerExitException:
     45     _call_teardown_hook(trainer)

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    558 self._data_connector.attach_data(
    559     model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    560 )
    562 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    563     self.state.fn,
    564     ckpt_path,
    565     model_provided=True,
    566     model_connected=self.lightning_module is not None,
    567 )
--> 568 self._run(model, ckpt_path=ckpt_path)
    570 assert self.state.stopped
    571 = False

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/, in Trainer._run(self, model, ckpt_path)
    968 self._signal_connector.register_signal_handlers()
    970 # ----------------------------
    972 # ----------------------------
--> 973 results = self._run_stage()
    975 # ----------------------------
    976 # POST-Training CLEAN UP
    977 # ----------------------------
    978 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/, in Trainer._run_stage(self)
   1014         self._run_sanity_check()
   1015     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1016
   1017     return None
   1018 raise RuntimeError(f"Unexpected state {self.state}")

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/, in
    199 try:
    200     self.on_advance_start()
--> 201     self.advance()
    202     self.on_advance_end()
    203     self._restarting = False

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/, in _FitLoop.advance(self)
    352 self._data_fetcher.setup(combined_loader)
    353 with self.trainer.profiler.profile("run_training_epoch"):
--> 354

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/, in, data_fetcher)
    131 while not self.done:
    132     try:
--> 133         self.advance(data_fetcher)
    134         self.on_advance_end()
    135         self._restarting = False

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/, in _TrainingEpochLoop.advance(self, data_fetcher)
    218             batch_output =[0], kwargs)
    219         else:
--> 220             batch_output =
    222 self.batch_progress.increment_processed()
    224 # update non-plateau LR schedulers
    225 # update epoch-interval ones only when we are at the end of training epoch

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/, in, kwargs)
     88 self.on_run_start()
     89 with suppress(StopIteration):  # no loop to break at this level
---> 90     self.advance(kwargs)
     91 self._restarting = False
     92 return self.on_run_end()

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/, in _ManualOptimization.advance(self, kwargs)
    106 trainer = self.trainer
    108 # manually capture logged metrics
--> 109 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    110 del kwargs  # release the batch from memory
    111 self.trainer.strategy.post_training_step()

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    288     return None
    290 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 291     output = fn(*args, **kwargs)
    293 # restore current_fx when nested context
    294 pl_module._current_fx_name = prev_fx_name

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/strategies/, in Strategy.training_step(self, *args, **kwargs)
    365 with self.precision_plugin.train_step_context():
    366     assert isinstance(self.model, TrainingStep)
--> 367     return self.model.training_step(*args, **kwargs)

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/train/, in AdversarialTrainingPlan.training_step(self, batch, batch_idx)
    553 else:
    554     opt1, opt2 = opts
--> 556 inference_outputs, _, scvi_loss = self.forward(
    557     batch, loss_kwargs=self.loss_kwargs
    558 )
    559 z = inference_outputs["z"]
    560 loss = scvi_loss.loss

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/train/, in TrainingPlan.forward(self, *args, **kwargs)
    276 def forward(self, *args, **kwargs):
    277     """Passthrough to the module's forward method."""
--> 278     return self.module(*args, **kwargs)

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/base/, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/base/, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    171 @auto_move_data
    172 def forward(
    173     self,
    183     | tuple[torch.Tensor, torch.Tensor, LossOutput]
    184 ):
    185     """Forward pass through the network.
    187     Parameters
    203         another return value.
    204     """
--> 205     return _generic_forward(
    206         self,
    207         tensors,
    208         inference_kwargs,
    209         generative_kwargs,
    210         loss_kwargs,
    211         get_inference_input_kwargs,
    212         get_generative_input_kwargs,
    213         compute_loss,
    214     )

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/base/, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
    744 get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs)
    746 inference_inputs = module._get_inference_input(
    747     tensors, **get_inference_input_kwargs
    748 )
--> 749 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
    750 generative_inputs = module._get_generative_input(
    751     tensors, inference_outputs, **get_generative_input_kwargs
    752 )
    753 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/base/, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/, in MULTIVAE.inference(self, x, y, batch_index, cont_covs, cat_covs, label, cell_idx, n_samples)
    612     categorical_input = ()
    614 # Z Encoders
--> 615 qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
    616     encoder_input_accessibility, batch_index, *categorical_input
    617 )
    618 qzm_expr, qzv_expr, z_expr = self.z_encoder_expression(
    619     encoder_input_expression, batch_index, *categorical_input
    620 )
    621 qzm_pro, qzv_pro, z_pro = self.z_encoder_protein(
    622     encoder_input_protein, batch_index, *categorical_input
    623 )

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/nn/modules/, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/nn/, in Encoder.forward(self, x, *cat_list)
    287 q_m = self.mean_encoder(q)
    288 q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 289 dist = Normal(q_m, q_v.sqrt())
    290 latent = self.z_transformation(dist.rsample())
    291 if self.return_dist:

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/distributions/, in Normal.__init__(self, loc, scale, validate_args)
     54 else:
     55     batch_shape = self.loc.size()
---> 56 super().__init__(batch_shape, validate_args=validate_args)

File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/distributions/, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     60         valid = constraint.check(value)
     61         if not valid.all():
---> 62             raise ValueError(
     63                 f"Expected parameter {param} "
     64                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     65                 f"of distribution {repr(self)} "
     66                 f"to satisfy the constraint {repr(constraint)}, "
     67                 f"but found invalid values:\n{value}"
     68             )
     69 super().__init__()

ValueError: Expected parameter loc (Tensor of shape (128, 19)) of distribution Normal(loc: torch.Size([128, 19]), scale: torch.Size([128, 19])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward0>)

Any help would be greatly appreciated. Thank you in advance!

bhhlee avatar Sep 13 '23 04:09 bhhlee

Hi, if you're able to, could you please share the data you're using?

martinkim0 avatar Oct 07 '23 00:10 martinkim0

I encountered the same error. The issue might be related to the `adversarial_mixing` argument of the `train` method, which defaults to `True`. In your scenario the dataset contains only one batch, so adversarial training is unnecessary, and it is advisable to set `adversarial_mixing=False` manually. When `adversarial_mixing=True` and there is only one batch, `n_class - 1` equals 0. The resulting division by zero produces NaN values.

Xinle-Deng avatar Feb 01 '24 14:02 Xinle-Deng

@martinkim0, you can reproduce the issue with the data used in the MultiVI tutorial:

import gzip
from pathlib import Path

import pooch

import scanpy as sc
import scvi

save_dir = tempfile.TemporaryDirectory()

def download_data(save_path: str, fname: str = "pbmc_10k"):
    data_paths = pooch.retrieve(

    for path in data_paths:
        with, "rb") as f_in:
            with open(path.replace(".gz", ""), "wb") as f_out:

    return str(Path(data_paths[0]).parent)

data_path = download_data(

# read multiomic data
adata =

sc.pp.filter_genes(adata, min_cells=int(adata.n_obs * 0.01))


model = scvi.model.MULTIVI(
    n_genes=(adata.var["modality"] == "Gene Expression").sum(),
    n_regions=(adata.var["modality"] == "Peaks").sum(),


I checked it with scvi==1.0.4, and @Xinle-Deng's suggestion seems to work. The fix might therefore be as simple as adding a check that sets `adversarial_mixing=False` when the data contain only a single batch.

WeilerP avatar Mar 21 '24 14:03 WeilerP