scvi-tools
SOLO model raises KeyError when the source scVI model uses continuous_covariate_keys
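I'm having trouble creating a SOLO model from an scVI model that was set up with continuous covariate data (percent mitochondrial reads, passed via the `continuous_covariate_keys` parameter). The scVI model itself trains without issue and its results look reasonable, but `SOLO.from_scvi_model` fails because the `percent.mito` column cannot be found when the AnnData object is validated. According to the traceback below, the failure occurs while `from_scvi_model` computes the latent representation of the simulated doublets (`doublet_adata`). Transferring the scVI model's AnnData setup to that object raises a `KeyError` because its `.obs` has no `percent.mito` column.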
I haven't tested whether this bug also occurs with `categorical_covariate_keys`.
contin_keys = ["percent.mito"]
scvi.model.SCVI.setup_anndata(adata,
batch_key = "batch",
categorical_covariate_keys = None,
continuous_covariate_keys = contin_keys)
model = scvi.model.SCVI(adata,
n_latent = 20,
n_layers = 2,
gene_likelihood = "nb")
model.train(max_epochs = 400, train_size = 0.9)
print(adata.obs)
# Output of the print statement above (truncated)
batch percent.mito _scvi_batch _scvi_labels
AAACCCACAAACAGGC-1 0.0 5.254359 0 0
AAACCCACAATTGCTG-1 0.0 3.539685 0 0
AAACCCACAGTTGCGC-1 0.0 5.130513 0 0
AAACCCAGTCCACTTC-1 0.0 6.404687 0 0
AAACGAAAGAGTTCGG-1 0.0 6.707605 0 0
... ... ... ... ...
TTTGTTGGTATGGTAA-1 5.0 4.456618 5 0
TTTGTTGGTTCAAAGA-1 5.0 3.795428 5 0
TTTGTTGGTTGCACGC-1 5.0 4.157337 5 0
TTTGTTGTCATGAAAG-1 5.0 2.413516 5 0
TTTGTTGTCCGAGAAG-1 5.0 8.514687 5 0
# The error occurs here
solo_model = scvi.external.SOLO.from_scvi_model(model, restrict_to_batch = 0)
print(solo_model)
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Input In [13], in <cell line: 1>()
----> 1 solo_model = scvi.external.SOLO.from_scvi_model(model, restrict_to_batch = 0)
2 print(solo_model)
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/external/solo/_model.py:197, in SOLO.from_scvi_model(cls, scvi_model, adata, restrict_to_batch, doublet_ratio, **classifier_kwargs)
195 f = io.StringIO()
196 with redirect_stdout(f):
--> 197 doublet_latent_rep = scvi_model.get_latent_representation(doublet_adata)
198 doublet_lib_size = scvi_model.get_latent_library_size(
199 doublet_adata, give_mean=give_mean_lib
200 )
201 doublet_adata = AnnData(
202 np.concatenate([doublet_latent_rep, np.log(doublet_lib_size)], axis=1)
203 )
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/model/base/_vaemixin.py:156, in VAEMixin.get_latent_representation(self, adata, indices, give_mean, mc_samples, batch_size)
129 r"""
130 Return the latent representation for each cell.
131
(...)
152 Low-dimensional representation for each cell
153 """
154 self._check_if_trained(warn=False)
--> 156 adata = self._validate_anndata(adata)
157 scdl = self._make_data_loader(
158 adata=adata, indices=indices, batch_size=batch_size
159 )
160 latent = []
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/model/base/_base_model.py:412, in BaseModelClass._validate_anndata(self, adata, copy_if_view)
406 if adata_manager is None:
407 logger.info(
408 "Input AnnData not setup with scvi-tools. "
409 + "attempting to transfer AnnData setup"
410 )
411 self._register_manager_for_instance(
--> 412 self.adata_manager.transfer_fields(adata)
413 )
414 else:
415 # Case where correct AnnDataManager is found, replay registration as necessary.
416 adata_manager.validate()
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/_manager.py:214, in AnnDataManager.transfer_fields(self, adata_target, **kwargs)
210 fields = self.fields
211 new_adata_manager = self.__class__(
212 fields=fields, setup_method_args=self._get_setup_method_args()
213 )
--> 214 new_adata_manager.register_fields(adata_target, self._registry, **kwargs)
215 return new_adata_manager
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/_manager.py:167, in AnnDataManager.register_fields(self, adata, source_registry, **transfer_kwargs)
162 if not field.is_empty:
163 # Transfer case: Source registry is used for validation and/or setup.
164 if source_registry is not None:
165 field_registry[
166 _constants._STATE_REGISTRY_KEY
--> 167 ] = field.transfer_field(
168 source_registry[_constants._FIELD_REGISTRIES_KEY][
169 field.registry_key
170 ][_constants._STATE_REGISTRY_KEY],
171 adata,
172 **transfer_kwargs,
173 )
174 else:
175 field_registry[
176 _constants._STATE_REGISTRY_KEY
177 ] = field.register_field(adata)
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/fields/_obsm_field.py:250, in NumericalJointObsField.transfer_field(self, state_registry, adata_target, **kwargs)
243 def transfer_field(
244 self,
245 state_registry: dict,
246 adata_target: AnnData,
247 **kwargs,
248 ) -> dict:
249 super().transfer_field(state_registry, adata_target, **kwargs)
--> 250 return self.register_field(adata_target)
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/fields/_obsm_field.py:239, in NumericalJointObsField.register_field(self, adata)
238 def register_field(self, adata: AnnData) -> dict:
--> 239 super().register_field(adata)
240 self._combine_obs_fields(adata)
241 return {self.COLUMNS_KEY: adata.obsm[self.attr_key].columns.to_numpy()}
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/fields/_base_field.py:67, in BaseAnnDataField.register_field(self, adata)
56 @abstractmethod
57 def register_field(self, adata: AnnOrMuData) -> dict:
58 """
59 Sets up the AnnData/MuData object and creates a mapping for scvi-tools models to use.
60
(...)
65 stored directly on the AnnData/MuData object.
66 """
---> 67 self.validate_field(adata)
68 return dict()
File ~/anaconda3/envs/scvi-env/lib/python3.9/site-packages/scvi/data/fields/_obsm_field.py:198, in JointObsField.validate_field(self, adata)
196 for obs_key in self._obs_keys:
197 if obs_key not in adata.obs:
--> 198 raise KeyError(f"{obs_key} not found in adata.obs.")
KeyError: 'percent.mito not found in adata.obs.'
Versions:
scvi-tools 0.17.3