scvi-tools
MultiVI: tensor contains invalid values (NaN) during training
I am trying to use MultiVI on a paired 10x Multiome (RNA + ATAC) dataset. I took the MuData object with raw counts and converted it into an AnnData object for MultiVI. However, when I get to training the model, it fails before the first epoch. I've seen several similar issues here, but none of those solutions seem to work.
I have a matrix of 161,598 cells × 195,122 features, containing only paired data.
adata_paired = ad.concat([rna.copy().T, atac.copy().T]).T
adata_paired.obs = adata_paired.obs.join(rna.obs['sample'])
adata_paired.obs["modality"] = "paired"
adata_paired
CPU times: user 49.5 s, sys: 8.99 s, total: 58.5 s
Wall time: 58.5 s
AnnData object with n_obs × n_vars = 161598 × 195122
obs: 'sample', 'modality'
var: 'gene_ids', 'feature_types', 'genome', 'interval'
mvi = scvi.model.MULTIVI(
    adata_paired,
    n_genes=n_genes,
    n_regions=n_regions,
)
mvi.view_anndata_setup()
mvi.train(use_gpu=False)
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Epoch 1/10: 0%| | 0/10 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
File <timed eval>:1
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/model/_multivi.py:353, in MULTIVI.train(self, max_epochs, lr, use_gpu, accelerator, devices, train_size, validation_size, shuffle_set_split, batch_size, weight_decay, eps, early_stopping, save_best, check_val_every_n_epoch, n_steps_kl_warmup, n_epochs_kl_warmup, adversarial_mixing, plan_kwargs, **kwargs)
338 training_plan = self._training_plan_cls(self.module, **plan_kwargs)
339 runner = self._train_runner_cls(
340 self,
341 training_plan=training_plan,
(...)
351 **kwargs,
352 )
--> 353 return runner()
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/train/_trainrunner.py:99, in TrainRunner.__call__(self)
96 if hasattr(self.data_splitter, "n_val"):
97 self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 99 self.trainer.fit(self.training_plan, self.data_splitter)
100 self._update_history()
102 # data splitter only gets these attrs after fit
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/train/_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
180 if isinstance(args[0], PyroTrainingPlan):
181 warnings.filterwarnings(
182 action="ignore",
183 category=UserWarning,
184 message="`LightningModule.configure_optimizers` returned `None`",
185 )
--> 186 super().fit(*args, **kwargs)
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:529, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
527 model = _maybe_unwrap_optimized(model)
528 self.strategy._lightning_module = model
--> 529 call._call_and_handle_interrupt(
530 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
531 )
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:42, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
40 if trainer.strategy.launcher is not None:
41 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 42 return trainer_fn(*args, **kwargs)
44 except _TunerExitException:
45 _call_teardown_hook(trainer)
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:568, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
558 self._data_connector.attach_data(
559 model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
560 )
562 ckpt_path = self._checkpoint_connector._select_ckpt_path(
563 self.state.fn,
564 ckpt_path,
565 model_provided=True,
566 model_connected=self.lightning_module is not None,
567 )
--> 568 self._run(model, ckpt_path=ckpt_path)
570 assert self.state.stopped
571 self.training = False
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:973, in Trainer._run(self, model, ckpt_path)
968 self._signal_connector.register_signal_handlers()
970 # ----------------------------
971 # RUN THE TRAINER
972 # ----------------------------
--> 973 results = self._run_stage()
975 # ----------------------------
976 # POST-Training CLEAN UP
977 # ----------------------------
978 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/trainer.py:1016, in Trainer._run_stage(self)
1014 self._run_sanity_check()
1015 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1016 self.fit_loop.run()
1017 return None
1018 raise RuntimeError(f"Unexpected state {self.state}")
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:201, in _FitLoop.run(self)
199 try:
200 self.on_advance_start()
--> 201 self.advance()
202 self.on_advance_end()
203 self._restarting = False
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/fit_loop.py:354, in _FitLoop.advance(self)
352 self._data_fetcher.setup(combined_loader)
353 with self.trainer.profiler.profile("run_training_epoch"):
--> 354 self.epoch_loop.run(self._data_fetcher)
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:133, in _TrainingEpochLoop.run(self, data_fetcher)
131 while not self.done:
132 try:
--> 133 self.advance(data_fetcher)
134 self.on_advance_end()
135 self._restarting = False
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/training_epoch_loop.py:220, in _TrainingEpochLoop.advance(self, data_fetcher)
218 batch_output = self.automatic_optimization.run(trainer.optimizers[0], kwargs)
219 else:
--> 220 batch_output = self.manual_optimization.run(kwargs)
222 self.batch_progress.increment_processed()
224 # update non-plateau LR schedulers
225 # update epoch-interval ones only when we are at the end of training epoch
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/manual.py:90, in _ManualOptimization.run(self, kwargs)
88 self.on_run_start()
89 with suppress(StopIteration): # no loop to break at this level
---> 90 self.advance(kwargs)
91 self._restarting = False
92 return self.on_run_end()
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/loops/optimization/manual.py:109, in _ManualOptimization.advance(self, kwargs)
106 trainer = self.trainer
108 # manually capture logged metrics
--> 109 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
110 del kwargs # release the batch from memory
111 self.trainer.strategy.post_training_step()
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/trainer/call.py:291, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
288 return None
290 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 291 output = fn(*args, **kwargs)
293 # restore current_fx when nested context
294 pl_module._current_fx_name = prev_fx_name
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/lightning/pytorch/strategies/strategy.py:367, in Strategy.training_step(self, *args, **kwargs)
365 with self.precision_plugin.train_step_context():
366 assert isinstance(self.model, TrainingStep)
--> 367 return self.model.training_step(*args, **kwargs)
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/train/_trainingplans.py:556, in AdversarialTrainingPlan.training_step(self, batch, batch_idx)
553 else:
554 opt1, opt2 = opts
--> 556 inference_outputs, _, scvi_loss = self.forward(
557 batch, loss_kwargs=self.loss_kwargs
558 )
559 z = inference_outputs["z"]
560 loss = scvi_loss.loss
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/train/_trainingplans.py:278, in TrainingPlan.forward(self, *args, **kwargs)
276 def forward(self, *args, **kwargs):
277 """Passthrough to the module's forward method."""
--> 278 return self.module(*args, **kwargs)
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
34 device = list({p.device for p in self.parameters()})
35 if len(device) > 1:
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/base/_base_module.py:205, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
171 @auto_move_data
172 def forward(
173 self,
(...)
183 | tuple[torch.Tensor, torch.Tensor, LossOutput]
184 ):
185 """Forward pass through the network.
186
187 Parameters
(...)
203 another return value.
204 """
--> 205 return _generic_forward(
206 self,
207 tensors,
208 inference_kwargs,
209 generative_kwargs,
210 loss_kwargs,
211 get_inference_input_kwargs,
212 get_generative_input_kwargs,
213 compute_loss,
214 )
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/base/_base_module.py:749, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
744 get_generative_input_kwargs = _get_dict_if_none(get_generative_input_kwargs)
746 inference_inputs = module._get_inference_input(
747 tensors, **get_inference_input_kwargs
748 )
--> 749 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
750 generative_inputs = module._get_generative_input(
751 tensors, inference_outputs, **get_generative_input_kwargs
752 )
753 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
34 device = list({p.device for p in self.parameters()})
35 if len(device) > 1:
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/module/_multivae.py:615, in MULTIVAE.inference(self, x, y, batch_index, cont_covs, cat_covs, label, cell_idx, n_samples)
612 categorical_input = ()
614 # Z Encoders
--> 615 qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility(
616 encoder_input_accessibility, batch_index, *categorical_input
617 )
618 qzm_expr, qzv_expr, z_expr = self.z_encoder_expression(
619 encoder_input_expression, batch_index, *categorical_input
620 )
621 qzm_pro, qzv_pro, z_pro = self.z_encoder_protein(
622 encoder_input_protein, batch_index, *categorical_input
623 )
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/nn/_base_components.py:289, in Encoder.forward(self, x, *cat_list)
287 q_m = self.mean_encoder(q)
288 q_v = self.var_activation(self.var_encoder(q)) + self.var_eps
--> 289 dist = Normal(q_m, q_v.sqrt())
290 latent = self.z_transformation(dist.rsample())
291 if self.return_dist:
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/distributions/normal.py:56, in Normal.__init__(self, loc, scale, validate_args)
54 else:
55 batch_shape = self.loc.size()
---> 56 super().__init__(batch_shape, validate_args=validate_args)
File /data/bhlee/miniconda3/envs/scvi-env/lib/python3.9/site-packages/torch/distributions/distribution.py:62, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
60 valid = constraint.check(value)
61 if not valid.all():
---> 62 raise ValueError(
63 f"Expected parameter {param} "
64 f"({type(value).__name__} of shape {tuple(value.shape)}) "
65 f"of distribution {repr(self)} "
66 f"to satisfy the constraint {repr(constraint)}, "
67 f"but found invalid values:\n{value}"
68 )
69 super().__init__()
ValueError: Expected parameter loc (Tensor of shape (128, 19)) of distribution Normal(loc: torch.Size([128, 19]), scale: torch.Size([128, 19])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
...,
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan],
[nan, nan, nan, ..., nan, nan, nan]], grad_fn=<AddmmBackward0>)
Any help would be greatly appreciated. Thank you in advance!
Hi, if you're able to, could you please share the data you're using?
I encountered the same error. The issue seems to be related to the adversarial_mixing argument of the train method, which defaults to True. In your scenario, since the dataset contains only a single batch, there is no need for adversarial training, so it is advisable to set adversarial_mixing=False manually. When adversarial_mixing is True and there is only one batch, n_class - 1 equals 0, and the resulting division by zero produces the NaNs.
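For example, in the original snippet above, the training call would become (a minimal sketch, assuming the rest of the setup is unchanged; adversarial_mixing is an argument of MULTIVI.train, as shown in the traceback's signature):

mvi.train(use_gpu=False, adversarial_mixing=False)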
@martinkim0, you can reproduce the issue with the data used in the MultiVI tutorial:
import gzip
import tempfile
from pathlib import Path

import pooch
import scanpy as sc
import scvi

save_dir = tempfile.TemporaryDirectory()
def download_data(save_path: str, fname: str = "pbmc_10k"):
    data_paths = pooch.retrieve(
        url="https://cf.10xgenomics.com/samples/cell-arc/2.0.0/pbmc_unsorted_10k/pbmc_unsorted_10k_filtered_feature_bc_matrix.tar.gz",
        known_hash="872b0dba467d972aa498812a857677ca7cf69050d4f9762b2cd4753b2be694a1",
        fname=fname,
        path=save_path,
        processor=pooch.Untar(),
        progressbar=True,
    )
    data_paths.sort()
    for path in data_paths:
        with gzip.open(path, "rb") as f_in:
            with open(path.replace(".gz", ""), "wb") as f_out:
                f_out.write(f_in.read())
    return str(Path(data_paths[0]).parent)
data_path = download_data(save_dir.name)
# read multiomic data
adata = scvi.data.read_10x_multiome(data_path)
adata.var_names_make_unique()
sc.pp.filter_genes(adata, min_cells=int(adata.n_obs * 0.01))
scvi.model.MULTIVI.setup_anndata(adata)
model = scvi.model.MULTIVI(
    adata,
    n_genes=(adata.var["modality"] == "Gene Expression").sum(),
    n_regions=(adata.var["modality"] == "Peaks").sum(),
)
model.train()
I checked this with scvi==1.0.4, and @Xinle-Deng's suggestion seems to work, so the fix might be as simple as adding a check that sets adversarial_mixing=False.
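A rough sketch of what such a check could look like on the user side (this assumes the model exposes the number of registered batches via summary_stats.n_batch, which may differ between scvi-tools versions; it is not the library's built-in fix):

# Hypothetical guard: only enable adversarial mixing when more than one batch is registered.
n_batch = model.summary_stats.n_batch  # assumption: available after setup_anndata
model.train(adversarial_mixing=n_batch > 1)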