contrastiveVI
contrastiveVI copied to clipboard
Issue during training of model
Dear contrastiveVI dev team,
First of all, thanks a lot for developing this cool model, and congrats on the Nature Methods paper about it! I was trying to apply it to a MIBI dataset harboring different drug treatment conditions, but unfortunately I ran into an issue that I can't figure out myself.
I run the following code (taken from the Alzheimer example), where treated_control is an AnnData object containing my single-cell data, with "Drug" being the condition column.
# imports
from contrastive_vi.model import ContrastiveVI
from pytorch_lightning.utilities.seed import seed_everything
seed_everything(42) # For reproducibility
treated_control = treated_control.copy()
ContrastiveVI.setup_anndata(treated_control) # setup adata for use with this model
model = ContrastiveVI(
treated_control,
n_salient_latent=10,
n_background_latent=10,
use_observed_lib_size=False
)
background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]
model.train(
check_val_every_n_epoch=1,
train_size=0.8,
background_indices=background_indices,
target_indices=target_indices,
use_gpu=False,
early_stopping=True,
max_epochs=500,
)
When running model.train, I get the following error message:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[20], line 4
1 background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
2 target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]
----> 4 model.train(
5 check_val_every_n_epoch=1,
6 train_size=0.8,
7 background_indices=background_indices,
8 target_indices=target_indices,
9 use_gpu=False,
10 early_stopping=True,
11 max_epochs=500,
12 )
File ~\Anaconda3\envs\ST0036\Lib\site-packages\contrastive_vi\model\base\training_mixin.py:88, in ContrastiveTrainingMixin.train(self, background_indices, target_indices, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
77 trainer_kwargs[es] = (
78 early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
79 )
80 runner = TrainRunner(
81 self,
82 training_plan=training_plan,
(...)
86 **trainer_kwargs,
87 )
---> 88 return runner()
File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainrunner.py:74, in TrainRunner.__call__(self)
71 if hasattr(self.data_splitter, "n_val"):
72 self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 74 self.trainer.fit(self.training_plan, self.data_splitter)
75 self._update_history()
77 # data splitter only gets these attrs after fit
File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
180 if isinstance(args[0], PyroTrainingPlan):
181 warnings.filterwarnings(
182 action="ignore",
183 category=UserWarning,
184 message="`LightningModule.configure_optimizers` returned `None`",
185 )
--> 186 super().fit(*args, **kwargs)
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
735 rank_zero_deprecation(
736 "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
737 " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
738 )
739 train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
741 self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
742 )
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
675 r"""
676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
677 as all errors should funnel through them
(...)
682 **kwargs: keyword arguments to be passed to `trainer_fn`
683 """
684 try:
--> 685 return trainer_fn(*args, **kwargs)
686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
687 except KeyboardInterrupt as exception:
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
775 # TODO: ckpt_path only in v1.7
776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
779 assert self.state.stopped
780 self.training = False
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1138, in Trainer._run(self, model, ckpt_path)
1136 self.call_hook("on_before_accelerator_backend_setup")
1137 self.accelerator.setup_environment()
-> 1138 self._call_setup_hook() # allow user to setup lightning_module in accelerator environment
1140 # check if we should delay restoring checkpoint till later
1141 if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1438, in Trainer._call_setup_hook(self)
1435 self.training_type_plugin.barrier("pre_setup")
1437 if self.datamodule is not None:
-> 1438 self.datamodule.setup(stage=fn)
1439 self.call_hook("setup", stage=fn)
1441 self.training_type_plugin.barrier("post_setup")
File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\core\datamodule.py:461, in LightningDataModule._track_data_hook_calls.<locals>.wrapped_fn(*args, **kwargs)
459 else:
460 attr = f"_has_{name}_{stage}"
--> 461 has_run = getattr(obj, attr)
462 setattr(obj, attr, True)
464 elif name == "prepare_data":
AttributeError: 'ContrastiveDataSplitter' object has no attribute '_has_setup_TrainerFn.FITTING'
I ran the package in a fresh conda environment. Any ideas where the issue may lie?
Thanks a ton for your help!
Best regards, Sven