scvi-tools icon indicating copy to clipboard operation
scvi-tools copied to clipboard

CellAssign tensor size error

Open andrewjkwok opened this issue 1 month ago • 0 comments

Hello,

I'm trying to use CellAssign as implemented within scvi-tools to map a query dataset onto my annotated reference. However, I am running into an error when trying to train the CellAssign model. I follow the tutorial in making sure library size is calculated, and subsetting the anndata object to only include genes in my marker gene list binary table. I have a list of 225 markers based on which I want to assign cells — but somehow there's a "tensor b" with a dimension of 363 when I try to run the training.

I tried following the tutorial with the sample follicular lymphoma data and it worked, so I suppose it is not an issue with the package versions. Any help would be much appreciated.

CSFadata_Heming_T.obs.index = CSFadata_Heming_T.obs.index.astype('str')
CSFadata_Heming_T.var.index = CSFadata_Heming_T.var.index.astype('str')
CSFadata_Heming_T.var_names_make_unique()
CSFadata_Heming_T.obs_names_make_unique()

lib_size = CSFadata_Heming_T.X.sum(1)
CSFadata_Heming_T.obs["size_factor"] = lib_size / np.mean(lib_size)

bdata = CSFadata_Heming_T[:, Tcelltype_markers_CellAssign_filtered.index].copy()

CellAssign.setup_anndata(bdata, size_factor_key="size_factor")
integration_model = CellAssign(bdata, Tcelltype_markers_CellAssign_filtered)
integration_model.train()
RuntimeError: The size of tensor a (225) must match the size of tensor b (363) at non-singleton dimension 0

The full error is:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File <timed eval>:1

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/external/cellassign/_model.py:232, in CellAssign.train(self, max_epochs, lr, accelerator, devices, train_size, validation_size, shuffle_set_split, batch_size, datasplitter_kwargs, plan_kwargs, early_stopping, early_stopping_patience, early_stopping_min_delta, **kwargs)
    222 training_plan = TrainingPlan(self.module, **plan_kwargs)
    223 runner = TrainRunner(
    224     self,
    225     training_plan=training_plan,
   (...)
    230     **kwargs,
    231 )
--> 232 return runner()

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/train/_trainrunner.py:98, in TrainRunner.__call__(self)
     95 if hasattr(self.data_splitter, "n_val"):
     96     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 98 self.trainer.fit(self.training_plan, self.data_splitter)
     99 self._update_history()
    101 # data splitter only gets these attrs after fit

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/train/_trainer.py:219, in Trainer.fit(self, *args, **kwargs)
    213 if isinstance(args[0], PyroTrainingPlan):
    214     warnings.filterwarnings(
    215         action="ignore",
    216         category=UserWarning,
    217         message="`LightningModule.configure_optimizers` returned `None`",
    218     )
--> 219 super().fit(*args, **kwargs)

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
    984 self._signal_connector.register_signal_handlers()
    986 # ----------------------------
    987 # RUN THE TRAINER
    988 # ----------------------------
--> 989 results = self._run_stage()
    991 # ----------------------------
    992 # POST-Training CLEAN UP
    993 # ----------------------------
    994 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1035, in Trainer._run_stage(self)
   1033         self._run_sanity_check()
   1034     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035         self.fit_loop.run()
   1036     return None
   1037 raise RuntimeError(f"Unexpected state {self.state}")

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self)
    200 try:
    201     self.on_advance_start()
--> 202     self.advance()
    203     self.on_advance_end()
    204     self._restarting = False

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:359, in _FitLoop.advance(self)
    357 with self.trainer.profiler.profile("run_training_epoch"):
    358     assert self._data_fetcher is not None
--> 359     self.epoch_loop.run(self._data_fetcher)

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
    134 while not self.done:
    135     try:
--> 136         self.advance(data_fetcher)
    137         self.on_advance_end(data_fetcher)
    138         self._restarting = False

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
    237 with trainer.profiler.profile("run_training_batch"):
    238     if trainer.lightning_module.automatic_optimization:
    239         # in automatic optimization, there can only be one optimizer
--> 240         batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
    241     else:
    242         batch_output = self.manual_optimization.run(kwargs)

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
    180         closure()
    182 # ------------------------------
    183 # BACKWARD PASS
    184 # ------------------------------
    185 # gradient update with accumulated gradients
    186 else:
--> 187     self._optimizer_step(batch_idx, closure)
    189 result = closure.consume_result()
    190 if result.loss is None:

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
    262     self.optim_progress.optimizer.step.increment_ready()
    264 # model hook
--> 265 call._call_lightning_module_hook(
    266     trainer,
    267     "optimizer_step",
    268     trainer.current_epoch,
    269     batch_idx,
    270     optimizer,
    271     train_step_and_backward_closure,
    272 )
    274 if not should_accumulate:
    275     self.optim_progress.optimizer.step.increment_completed()

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
    154 pl_module._current_fx_name = hook_name
    156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157     output = fn(*args, **kwargs)
    159 # restore current_fx when nested context
    160 pl_module._current_fx_name = prev_fx_name

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
   1252 def optimizer_step(
   1253     self,
   1254     epoch: int,
   (...)
   1257     optimizer_closure: Optional[Callable[[], Any]] = None,
   1258 ) -> None:
   1259     r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
   1260     the optimizer.
   1261 
   (...)
   1289 
   1290     """
-> 1291     optimizer.step(closure=optimizer_closure)

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
    148     raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
    150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
    153 self._on_after_step()
    155 return step_output

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
    228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
    229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
    115 """Hook to run the optimizer step."""
    116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/optim/optimizer.py:391, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
    386         else:
    387             raise RuntimeError(
    388                 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
    389             )
--> 391 out = func(*args, **kwargs)
    392 self._optimizer_step_code()
    394 # call optimizer step post hooks

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
     74     torch.set_grad_enabled(self.defaults['differentiable'])
     75     torch._dynamo.graph_break()
---> 76     ret = func(self, *args, **kwargs)
     77 finally:
     78     torch._dynamo.graph_break()

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/optim/adam.py:148, in Adam.step(self, closure)
    146 if closure is not None:
    147     with torch.enable_grad():
--> 148         loss = closure()
    150 for group in self.param_groups:
    151     params_with_grad = []

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure)
     91 def _wrap_closure(
     92     self,
     93     model: "pl.LightningModule",
     94     optimizer: Optimizer,
     95     closure: Callable[[], Any],
     96 ) -> Any:
     97     """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
     98     hook is called.
     99 
   (...)
    102 
    103     """
--> 104     closure_result = closure()
    105     self._after_closure(model, optimizer)
    106     return closure_result

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
    139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140     self._result = self.closure(*args, **kwargs)
    141     return self._result.loss

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
    124 @torch.enable_grad()
    125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126     step_output = self._step_fn()
    128     if step_output.closure_loss is None:
    129         self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:315, in _AutomaticOptimization._training_step(self, kwargs)
    312 trainer = self.trainer
    314 # manually capture logged metrics
--> 315 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
    316 self.trainer.strategy.post_training_step()  # unused hook - call anyway for backward compatibility
    318 return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
    306     return None
    308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309     output = fn(*args, **kwargs)
    311 # restore current_fx when nested context
    312 pl_module._current_fx_name = prev_fx_name

File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:382, in Strategy.training_step(self, *args, **kwargs)
    380 if self.model != self.lightning_module:
    381     return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 382 return self.lightning_module.training_step(*args, **kwargs)

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/train/_trainingplans.py:344, in TrainingPlan.training_step(self, batch, batch_idx)
    342     self.loss_kwargs.update({"kl_weight": kl_weight})
    343     self.log("kl_weight", kl_weight, on_step=True, on_epoch=False)
--> 344 _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
    345 self.log(
    346     "train_loss",
    347     scvi_loss.loss,
   (...)
    350     sync_dist=self.use_sync_dist,
    351 )
    352 self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/train/_trainingplans.py:278, in TrainingPlan.forward(self, *args, **kwargs)
    276 def forward(self, *args, **kwargs):
    277     """Passthrough to the module's forward method."""
--> 278     return self.module(*args, **kwargs)

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/module/base/_base_module.py:203, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    172 @auto_move_data
    173 def forward(
    174     self,
   (...)
    181     compute_loss=True,
    182 ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]:
    183     """Forward pass through the network.
    184 
    185     Parameters
   (...)
    201         another return value.
    202     """
--> 203     return _generic_forward(
    204         self,
    205         tensors,
    206         inference_kwargs,
    207         generative_kwargs,
    208         loss_kwargs,
    209         get_inference_input_kwargs,
    210         get_generative_input_kwargs,
    211         compute_loss,
    212     )

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/module/base/_base_module.py:743, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
    739 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
    740 generative_inputs = module._get_generative_input(
    741     tensors, inference_outputs, **get_generative_input_kwargs
    742 )
--> 743 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)
    744 if compute_loss:
    745     losses = module.loss(tensors, inference_outputs, generative_outputs, **loss_kwargs)

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     30 # decorator only necessary after training
     31 if self.training:
---> 32     return fn(self, *args, **kwargs)
     34 device = list({p.device for p in self.parameters()})
     35 if len(device) > 1:

File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/external/cellassign/_module.py:167, in CellAssignModule.generative(self, x, size_factor, design_matrix)
    165 # base gene expression
    166 b_g_0 = self.b_g_0.unsqueeze(-1).expand(n_cells, self.n_genes, self.n_labels)
--> 167 delta_rho = delta * self.rho
    168 delta_rho = delta_rho.expand(n_cells, self.n_genes, self.n_labels)  # (n, g, c)
    169 log_mu_ngc = base_mean + delta_rho + b_g_0

RuntimeError: The size of tensor a (225) must match the size of tensor b (363) at non-singleton dimension 0

Versions:

VERSION

Last run with scvi-tools version: 1.1.2

As a quick aside: the CellAssign tutorial doesn't seem to import the numpy library, which is required for calculating the library size — but that's a very minor issue.

andrewjkwok avatar Jun 11 '24 08:06 andrewjkwok