CellAssign tensor size error
Hello,
I'm trying to use CellAssign, as implemented in scvi-tools, to map a query dataset onto my annotated reference, but I'm running into an error when training the CellAssign model. Following the tutorial, I compute the library size and subset the AnnData object to the genes in my binary marker-gene table. My marker list contains 225 genes that I want to assign cells with, yet training fails with a "tensor b" of size 363.
I followed the tutorial with the sample follicular lymphoma data and it worked, so I don't think this is a package-version issue. Any help would be much appreciated.
import numpy as np

# Ensure string indices and unique cell/gene names
CSFadata_Heming_T.obs.index = CSFadata_Heming_T.obs.index.astype("str")
CSFadata_Heming_T.var.index = CSFadata_Heming_T.var.index.astype("str")
CSFadata_Heming_T.var_names_make_unique()
CSFadata_Heming_T.obs_names_make_unique()

# Size factor = library size normalised by its mean, as in the tutorial
lib_size = CSFadata_Heming_T.X.sum(1)
CSFadata_Heming_T.obs["size_factor"] = lib_size / np.mean(lib_size)

# Subset to the genes in the binary marker table (225 markers)
bdata = CSFadata_Heming_T[:, Tcelltype_markers_CellAssign_filtered.index].copy()

CellAssign.setup_anndata(bdata, size_factor_key="size_factor")
integration_model = CellAssign(bdata, Tcelltype_markers_CellAssign_filtered)
integration_model.train()
RuntimeError: The size of tensor a (225) must match the size of tensor b (363) at non-singleton dimension 0
The full error is:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
File <timed eval>:1
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/external/cellassign/_model.py:232, in CellAssign.train(self, max_epochs, lr, accelerator, devices, train_size, validation_size, shuffle_set_split, batch_size, datasplitter_kwargs, plan_kwargs, early_stopping, early_stopping_patience, early_stopping_min_delta, **kwargs)
222 training_plan = TrainingPlan(self.module, **plan_kwargs)
223 runner = TrainRunner(
224 self,
225 training_plan=training_plan,
(...)
230 **kwargs,
231 )
--> 232 return runner()
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/train/_trainrunner.py:98, in TrainRunner.__call__(self)
95 if hasattr(self.data_splitter, "n_val"):
96 self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 98 self.trainer.fit(self.training_plan, self.data_splitter)
99 self._update_history()
101 # data splitter only gets these attrs after fit
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/train/_trainer.py:219, in Trainer.fit(self, *args, **kwargs)
213 if isinstance(args[0], PyroTrainingPlan):
214 warnings.filterwarnings(
215 action="ignore",
216 category=UserWarning,
217 message="`LightningModule.configure_optimizers` returned `None`",
218 )
--> 219 super().fit(*args, **kwargs)
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
542 self.state.status = TrainerStatus.RUNNING
543 self.training = True
--> 544 call._call_and_handle_interrupt(
545 self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
546 )
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
42 if trainer.strategy.launcher is not None:
43 return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44 return trainer_fn(*args, **kwargs)
46 except _TunerExitException:
47 _call_teardown_hook(trainer)
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
573 assert self.state.fn is not None
574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
575 self.state.fn,
576 ckpt_path,
577 model_provided=True,
578 model_connected=self.lightning_module is not None,
579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
582 assert self.state.stopped
583 self.training = False
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
984 self._signal_connector.register_signal_handlers()
986 # ----------------------------
987 # RUN THE TRAINER
988 # ----------------------------
--> 989 results = self._run_stage()
991 # ----------------------------
992 # POST-Training CLEAN UP
993 # ----------------------------
994 log.debug(f"{self.__class__.__name__}: trainer tearing down")
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:1035, in Trainer._run_stage(self)
1033 self._run_sanity_check()
1034 with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035 self.fit_loop.run()
1036 return None
1037 raise RuntimeError(f"Unexpected state {self.state}")
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:202, in _FitLoop.run(self)
200 try:
201 self.on_advance_start()
--> 202 self.advance()
203 self.on_advance_end()
204 self._restarting = False
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:359, in _FitLoop.advance(self)
357 with self.trainer.profiler.profile("run_training_epoch"):
358 assert self._data_fetcher is not None
--> 359 self.epoch_loop.run(self._data_fetcher)
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:136, in _TrainingEpochLoop.run(self, data_fetcher)
134 while not self.done:
135 try:
--> 136 self.advance(data_fetcher)
137 self.on_advance_end(data_fetcher)
138 self._restarting = False
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py:240, in _TrainingEpochLoop.advance(self, data_fetcher)
237 with trainer.profiler.profile("run_training_batch"):
238 if trainer.lightning_module.automatic_optimization:
239 # in automatic optimization, there can only be one optimizer
--> 240 batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
241 else:
242 batch_output = self.manual_optimization.run(kwargs)
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:187, in _AutomaticOptimization.run(self, optimizer, batch_idx, kwargs)
180 closure()
182 # ------------------------------
183 # BACKWARD PASS
184 # ------------------------------
185 # gradient update with accumulated gradients
186 else:
--> 187 self._optimizer_step(batch_idx, closure)
189 result = closure.consume_result()
190 if result.loss is None:
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:265, in _AutomaticOptimization._optimizer_step(self, batch_idx, train_step_and_backward_closure)
262 self.optim_progress.optimizer.step.increment_ready()
264 # model hook
--> 265 call._call_lightning_module_hook(
266 trainer,
267 "optimizer_step",
268 trainer.current_epoch,
269 batch_idx,
270 optimizer,
271 train_step_and_backward_closure,
272 )
274 if not should_accumulate:
275 self.optim_progress.optimizer.step.increment_completed()
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:157, in _call_lightning_module_hook(trainer, hook_name, pl_module, *args, **kwargs)
154 pl_module._current_fx_name = hook_name
156 with trainer.profiler.profile(f"[LightningModule]{pl_module.__class__.__name__}.{hook_name}"):
--> 157 output = fn(*args, **kwargs)
159 # restore current_fx when nested context
160 pl_module._current_fx_name = prev_fx_name
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/core/module.py:1291, in LightningModule.optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure)
1252 def optimizer_step(
1253 self,
1254 epoch: int,
(...)
1257 optimizer_closure: Optional[Callable[[], Any]] = None,
1258 ) -> None:
1259 r"""Override this method to adjust the default way the :class:`~lightning.pytorch.trainer.trainer.Trainer` calls
1260 the optimizer.
1261
(...)
1289
1290 """
-> 1291 optimizer.step(closure=optimizer_closure)
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py:151, in LightningOptimizer.step(self, closure, **kwargs)
148 raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")
150 assert self._strategy is not None
--> 151 step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
153 self._on_after_step()
155 return step_output
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:230, in Strategy.optimizer_step(self, optimizer, closure, model, **kwargs)
228 # TODO(fabric): remove assertion once strategy's optimizer_step typing is fixed
229 assert isinstance(model, pl.LightningModule)
--> 230 return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py:117, in Precision.optimizer_step(self, optimizer, model, closure, **kwargs)
115 """Hook to run the optimizer step."""
116 closure = partial(self._wrap_closure, model, optimizer, closure)
--> 117 return optimizer.step(closure=closure, **kwargs)
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/optim/optimizer.py:391, in Optimizer.profile_hook_step.<locals>.wrapper(*args, **kwargs)
386 else:
387 raise RuntimeError(
388 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
389 )
--> 391 out = func(*args, **kwargs)
392 self._optimizer_step_code()
394 # call optimizer step post hooks
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/optim/optimizer.py:76, in _use_grad_for_differentiable.<locals>._use_grad(self, *args, **kwargs)
74 torch.set_grad_enabled(self.defaults['differentiable'])
75 torch._dynamo.graph_break()
---> 76 ret = func(self, *args, **kwargs)
77 finally:
78 torch._dynamo.graph_break()
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/optim/adam.py:148, in Adam.step(self, closure)
146 if closure is not None:
147 with torch.enable_grad():
--> 148 loss = closure()
150 for group in self.param_groups:
151 params_with_grad = []
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py:104, in Precision._wrap_closure(self, model, optimizer, closure)
91 def _wrap_closure(
92 self,
93 model: "pl.LightningModule",
94 optimizer: Optimizer,
95 closure: Callable[[], Any],
96 ) -> Any:
97 """This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
98 hook is called.
99
(...)
102
103 """
--> 104 closure_result = closure()
105 self._after_closure(model, optimizer)
106 return closure_result
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:140, in Closure.__call__(self, *args, **kwargs)
139 def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
--> 140 self._result = self.closure(*args, **kwargs)
141 return self._result.loss
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:126, in Closure.closure(self, *args, **kwargs)
124 @torch.enable_grad()
125 def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
--> 126 step_output = self._step_fn()
128 if step_output.closure_loss is None:
129 self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:315, in _AutomaticOptimization._training_step(self, kwargs)
312 trainer = self.trainer
314 # manually capture logged metrics
--> 315 training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
316 self.trainer.strategy.post_training_step() # unused hook - call anyway for backward compatibility
318 return self.output_result_cls.from_training_step_output(training_step_output, trainer.accumulate_grad_batches)
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py:309, in _call_strategy_hook(trainer, hook_name, *args, **kwargs)
306 return None
308 with trainer.profiler.profile(f"[Strategy]{trainer.strategy.__class__.__name__}.{hook_name}"):
--> 309 output = fn(*args, **kwargs)
311 # restore current_fx when nested context
312 pl_module._current_fx_name = prev_fx_name
File /gpfs3/users/jknight/nin561/.local/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py:382, in Strategy.training_step(self, *args, **kwargs)
380 if self.model != self.lightning_module:
381 return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
--> 382 return self.lightning_module.training_step(*args, **kwargs)
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/train/_trainingplans.py:344, in TrainingPlan.training_step(self, batch, batch_idx)
342 self.loss_kwargs.update({"kl_weight": kl_weight})
343 self.log("kl_weight", kl_weight, on_step=True, on_epoch=False)
--> 344 _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
345 self.log(
346 "train_loss",
347 scvi_loss.loss,
(...)
350 sync_dist=self.use_sync_dist,
351 )
352 self.compute_and_log_metrics(scvi_loss, self.train_metrics, "train")
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/train/_trainingplans.py:278, in TrainingPlan.forward(self, *args, **kwargs)
276 def forward(self, *args, **kwargs):
277 """Passthrough to the module's forward method."""
--> 278 return self.module(*args, **kwargs)
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
34 device = list({p.device for p in self.parameters()})
35 if len(device) > 1:
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/module/base/_base_module.py:203, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
172 @auto_move_data
173 def forward(
174 self,
(...)
181 compute_loss=True,
182 ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, LossOutput]:
183 """Forward pass through the network.
184
185 Parameters
(...)
201 another return value.
202 """
--> 203 return _generic_forward(
204 self,
205 tensors,
206 inference_kwargs,
207 generative_kwargs,
208 loss_kwargs,
209 get_inference_input_kwargs,
210 get_generative_input_kwargs,
211 compute_loss,
212 )
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/module/base/_base_module.py:743, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
739 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
740 generative_inputs = module._get_generative_input(
741 tensors, inference_outputs, **get_generative_input_kwargs
742 )
--> 743 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)
744 if compute_loss:
745 losses = module.loss(tensors, inference_outputs, generative_outputs, **loss_kwargs)
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/module/base/_decorators.py:32, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
30 # decorator only necessary after training
31 if self.training:
---> 32 return fn(self, *args, **kwargs)
34 device = list({p.device for p in self.parameters()})
35 if len(device) > 1:
File /well/jknight/projects/bmrcAudit/toAudit/Andrew_Kwok/python/NMDA_2023/lib/python3.10/site-packages/scvi/external/cellassign/_module.py:167, in CellAssignModule.generative(self, x, size_factor, design_matrix)
165 # base gene expression
166 b_g_0 = self.b_g_0.unsqueeze(-1).expand(n_cells, self.n_genes, self.n_labels)
--> 167 delta_rho = delta * self.rho
168 delta_rho = delta_rho.expand(n_cells, self.n_genes, self.n_labels) # (n, g, c)
169 log_mu_ngc = base_mean + delta_rho + b_g_0
RuntimeError: The size of tensor a (225) must match the size of tensor b (363) at non-singleton dimension 0
Versions:
Last run with scvi-tools version: 1.1.2
As a quick aside: the CellAssign tutorial doesn't seem to import numpy, which is needed for the library-size calculation, but that's a very minor issue.
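In case it helps with diagnosing the 225 vs. 363 mismatch, this is roughly how I'm sanity-checking the dimensions on my side (a minimal sketch; the object names are the ones from my session above, and `bdata` is built the same way as in the snippet):

markers = Tcelltype_markers_CellAssign_filtered

# How many marker genes, and are they unique / present in the reference?
print("marker table shape (genes x cell types):", markers.shape)
print("duplicate marker names:", markers.index.duplicated().sum())
print("markers missing from adata.var_names:",
      (~markers.index.isin(CSFadata_Heming_T.var_names)).sum())

# Shape of the subsetted object actually passed to CellAssign
print("bdata shape (cells x genes):", bdata.shape)

# The marker matrix rows should line up one-to-one with bdata's genes
assert bdata.n_vars == markers.shape[0], (
    f"gene dimension mismatch: {bdata.n_vars} genes in bdata "
    f"vs {markers.shape[0]} rows in the marker table"
)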