scvi-tools

SOLO model raises KeyError when using `continuous_covariate_keys`

Status: Open. Nicholas-Everetts opened this issue 1 year ago · 0 comments

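I'm having trouble creating a SOLO model from an scVI model that was given continuous covariate data (the percentage of mitochondrial reads, passed via the `continuous_covariate_keys` parameter). The scVI model itself trains without problems and its results look reasonable. However, `SOLO.from_scvi_model` fails with a `KeyError` because the `percent.mito` column cannot be found when the AnnData object is validated. The traceback below shows where this happens. The error is raised while SOLO computes the latent representation of the simulated doublets (`doublet_adata`). That AnnData object is not set up with scvi-tools, so scvi-tools tries to transfer the setup from the original model. The transfer then fails because `percent.mito` is not present in `doublet_adata.obs`.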

I haven't tested whether this bug also occurs when using `categorical_covariate_keys`.

contin_keys = ["percent.mito"]
scvi.model.SCVI.setup_anndata(adata,
                              batch_key = "batch",
                              categorical_covariate_keys = None,
                              continuous_covariate_keys = contin_keys)
model = scvi.model.SCVI(adata,
                        n_latent = 20,
                        n_layers = 2,
                        gene_likelihood = "nb")
model.train(max_epochs = 400, train_size = 0.9)
print(adata.obs)
#Output of the above print statement
                          batch  percent.mito  _scvi_batch  _scvi_labels
AAACCCACAAACAGGC-1        0.0      5.254359            0             0
AAACCCACAATTGCTG-1        0.0      3.539685            0             0
AAACCCACAGTTGCGC-1        0.0      5.130513            0             0
AAACCCAGTCCACTTC-1        0.0      6.404687            0             0
AAACGAAAGAGTTCGG-1        0.0      6.707605            0             0
...                       ...           ...          ...           ...
TTTGTTGGTATGGTAA-1        5.0      4.456618            5             0
TTTGTTGGTTCAAAGA-1        5.0      3.795428            5             0
TTTGTTGGTTGCACGC-1        5.0      4.157337            5             0
TTTGTTGTCATGAAAG-1        5.0      2.413516            5             0
TTTGTTGTCCGAGAAG-1        5.0      8.514687            5             0
#Error occurs here
solo_model = scvi.external.SOLO.from_scvi_model(model, restrict_to_batch =  0)
print(solo_model)
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Input In [13], in <cell line: 1>()
----> 1 solo_model = scvi.external.SOLO.from_scvi_model(model, restrict_to_batch =  0)
      2 print(solo_model)

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/external/solo/_model.py:197, in SOLO.from_scvi_model(cls, scvi_model, adata, restrict_to_batch, doublet_ratio, **classifier_kwargs)
    195 f = io.StringIO()
    196 with redirect_stdout(f):
--> 197     doublet_latent_rep = scvi_model.get_latent_representation(doublet_adata)
    198     doublet_lib_size = scvi_model.get_latent_library_size(
    199         doublet_adata, give_mean=give_mean_lib
    200     )
    201     doublet_adata = AnnData(
    202         np.concatenate([doublet_latent_rep, np.log(doublet_lib_size)], axis=1)
    203     )

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/model/base/_vaemixin.py:156, in VAEMixin.get_latent_representation(self, adata, indices, give_mean, mc_samples, batch_size)
    129 r"""
    130 Return the latent representation for each cell.
    131 
   (...)
    152     Low-dimensional representation for each cell
    153 """
    154 self._check_if_trained(warn=False)
--> 156 adata = self._validate_anndata(adata)
    157 scdl = self._make_data_loader(
    158     adata=adata, indices=indices, batch_size=batch_size
    159 )
    160 latent = []

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/model/base/_base_model.py:412, in BaseModelClass._validate_anndata(self, adata, copy_if_view)
    406 if adata_manager is None:
    407     logger.info(
    408         "Input AnnData not setup with scvi-tools. "
    409         + "attempting to transfer AnnData setup"
    410     )
    411     self._register_manager_for_instance(
--> 412         self.adata_manager.transfer_fields(adata)
    413     )
    414 else:
    415     # Case where correct AnnDataManager is found, replay registration as necessary.
    416     adata_manager.validate()

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/_manager.py:214, in AnnDataManager.transfer_fields(self, adata_target, **kwargs)
    210 fields = self.fields
    211 new_adata_manager = self.__class__(
    212     fields=fields, setup_method_args=self._get_setup_method_args()
    213 )
--> 214 new_adata_manager.register_fields(adata_target, self._registry, **kwargs)
    215 return new_adata_manager

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/_manager.py:167, in AnnDataManager.register_fields(self, adata, source_registry, **transfer_kwargs)
    162 if not field.is_empty:
    163     # Transfer case: Source registry is used for validation and/or setup.
    164     if source_registry is not None:
    165         field_registry[
    166             _constants._STATE_REGISTRY_KEY
--> 167         ] = field.transfer_field(
    168             source_registry[_constants._FIELD_REGISTRIES_KEY][
    169                 field.registry_key
    170             ][_constants._STATE_REGISTRY_KEY],
    171             adata,
    172             **transfer_kwargs,
    173         )
    174     else:
    175         field_registry[
    176             _constants._STATE_REGISTRY_KEY
    177         ] = field.register_field(adata)

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/fields/_obsm_field.py:250, in NumericalJointObsField.transfer_field(self, state_registry, adata_target, **kwargs)
    243 def transfer_field(
    244     self,
    245     state_registry: dict,
    246     adata_target: AnnData,
    247     **kwargs,
    248 ) -> dict:
    249     super().transfer_field(state_registry, adata_target, **kwargs)
--> 250     return self.register_field(adata_target)

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/fields/_obsm_field.py:239, in NumericalJointObsField.register_field(self, adata)
    238 def register_field(self, adata: AnnData) -> dict:
--> 239     super().register_field(adata)
    240     self._combine_obs_fields(adata)
    241     return {self.COLUMNS_KEY: adata.obsm[self.attr_key].columns.to_numpy()}

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/fields/_base_field.py:67, in BaseAnnDataField.register_field(self, adata)
     56 @abstractmethod
     57 def register_field(self, adata: AnnOrMuData) -> dict:
     58     """
     59     Sets up the AnnData/MuData object and creates a mapping for scvi-tools models to use.
     60 
   (...)
     65         stored directly on the AnnData/MuData object.
     66     """
---> 67     self.validate_field(adata)
     68     return dict()

File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/fields/_obsm_field.py:198, in JointObsField.validate_field(self, adata)
    196 for obs_key in self._obs_keys:
    197     if obs_key not in adata.obs:
--> 198         raise KeyError(f"{obs_key} not found in adata.obs.")

KeyError: 'percent.mito not found in adata.obs.'

Versions:

scvi-tools 0.17.3

Posted by Nicholas-Everetts on Sep 15, 2022 at 18:09.