contrastiveVI icon indicating copy to clipboard operation
contrastiveVI copied to clipboard

Issue during training of model

Open svntrx opened this issue 1 year ago • 9 comments

Dear contrastiveVI dev team,

First of all thanks a lot for developing this cool model and congrats on the Nature Methods paper on it! I was trying to apply it to a MIBI dataset harboring different drug treatment conditions, but unfortunately ran into an issue I can't seem to figure out myself.

I run the following code (taken from the Alzheimer example), wherein treated_control is an AnnData object containing my single-cell data, with "Drug" being the condition column.

# imports
import numpy as np
from contrastive_vi.model import ContrastiveVI
from pytorch_lightning.utilities.seed import seed_everything

seed_everything(42) # For reproducibility

treated_control = treated_control.copy()
ContrastiveVI.setup_anndata(treated_control) # setup adata for use with this model

model = ContrastiveVI(
    treated_control,
    n_salient_latent=10,
    n_background_latent=10,
    use_observed_lib_size=False
)

background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]

model.train(
    check_val_every_n_epoch=1,
    train_size=0.8,
    background_indices=background_indices,
    target_indices=target_indices,
    use_gpu=False,
    early_stopping=True,
    max_epochs=500,
)

Running model.train, I get the following error message:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[20], line 4
      1 background_indices = np.where(treated_control.obs["Drug"] == "CTRL")[0]
      2 target_indices = np.where(treated_control.obs["Drug"] != "CTRL")[0]
----> 4 model.train(
      5     check_val_every_n_epoch=1,
      6     train_size=0.8,
      7     background_indices=background_indices,
      8     target_indices=target_indices,
      9     use_gpu=False,
     10     early_stopping=True,
     11     max_epochs=500,
     12 )

File ~\Anaconda3\envs\ST0036\Lib\site-packages\contrastive_vi\model\base\training_mixin.py:88, in ContrastiveTrainingMixin.train(self, background_indices, target_indices, max_epochs, use_gpu, train_size, validation_size, batch_size, early_stopping, plan_kwargs, **trainer_kwargs)
     77 trainer_kwargs[es] = (
     78     early_stopping if es not in trainer_kwargs.keys() else trainer_kwargs[es]
     79 )
     80 runner = TrainRunner(
     81     self,
     82     training_plan=training_plan,
   (...)
     86     **trainer_kwargs,
     87 )
---> 88 return runner()

File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainrunner.py:74, in TrainRunner.__call__(self)
     71 if hasattr(self.data_splitter, "n_val"):
     72     self.training_plan.n_obs_validation = self.data_splitter.n_val
---> 74 self.trainer.fit(self.training_plan, self.data_splitter)
     75 self._update_history()
     77 # data splitter only gets these attrs after fit

File ~\Anaconda3\envs\ST0036\Lib\site-packages\scvi\train\_trainer.py:186, in Trainer.fit(self, *args, **kwargs)
    180 if isinstance(args[0], PyroTrainingPlan):
    181     warnings.filterwarnings(
    182         action="ignore",
    183         category=UserWarning,
    184         message="`LightningModule.configure_optimizers` returned `None`",
    185     )
--> 186 super().fit(*args, **kwargs)

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:740, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, train_dataloader, ckpt_path)
    735     rank_zero_deprecation(
    736         "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
    737         " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
    738     )
    739     train_dataloaders = train_dataloader
--> 740 self._call_and_handle_interrupt(
    741     self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    742 )

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:685, in Trainer._call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    675 r"""
    676 Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict)
    677 as all errors should funnel through them
   (...)
    682     **kwargs: keyword arguments to be passed to `trainer_fn`
    683 """
    684 try:
--> 685     return trainer_fn(*args, **kwargs)
    686 # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    687 except KeyboardInterrupt as exception:

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:777, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    775 # TODO: ckpt_path only in v1.7
    776 ckpt_path = ckpt_path or self.resume_from_checkpoint
--> 777 self._run(model, ckpt_path=ckpt_path)
    779 assert self.state.stopped
    780 self.training = False

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1138, in Trainer._run(self, model, ckpt_path)
   1136 self.call_hook("on_before_accelerator_backend_setup")
   1137 self.accelerator.setup_environment()
-> 1138 self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment
   1140 # check if we should delay restoring checkpoint till later
   1141 if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\trainer\trainer.py:1438, in Trainer._call_setup_hook(self)
   1435 self.training_type_plugin.barrier("pre_setup")
   1437 if self.datamodule is not None:
-> 1438     self.datamodule.setup(stage=fn)
   1439 self.call_hook("setup", stage=fn)
   1441 self.training_type_plugin.barrier("post_setup")

File ~\Anaconda3\envs\ST0036\Lib\site-packages\pytorch_lightning\core\datamodule.py:461, in LightningDataModule._track_data_hook_calls.<locals>.wrapped_fn(*args, **kwargs)
    459     else:
    460         attr = f"_has_{name}_{stage}"
--> 461         has_run = getattr(obj, attr)
    462         setattr(obj, attr, True)
    464 elif name == "prepare_data":

AttributeError: 'ContrastiveDataSplitter' object has no attribute '_has_setup_TrainerFn.FITTING'

I ran the package in a fresh conda environment. Any ideas where the issue may lie?

Thanks a ton for your help!

Best regards, Sven

svntrx avatar Aug 16 '23 09:08 svntrx