scvi-tools
Reference mapping with MuData-based models
(Addresses #2393) While certain multimodal models (e.g. totalVI) can be used with either AnnData or MuData objects, the ArchesMixin class currently supports only AnnData-based models, so reference mapping cannot be performed with MuData-based models. This PR extends the ArchesMixin class to handle MuData-based models by making the appropriate changes to the load_query_data function and adding a new prepare_query_mudata function.
@martinkim0 my initial PR is still a work-in-progress (e.g. no unit tests yet), but I thought it would be helpful to get some initial feedback from you. To avoid duplicating code, my instinct was for prepare_query_mudata to be essentially a lightweight wrapper that calls prepare_query_anndata for each modality in the MuData object. However, when attempting to do this naively I quickly ran into a problem: var_names for MuData objects are saved as a single list covering all modalities. Because variable names carry no identifier for their modality of origin, a protein-modality AnnData would, for example, erroneously be padded with zeros for "missing" RNA-modality features. To work around this, I modified the base save function so that variable names are saved as a dictionary (with modalities as keys) for MuData objects.
This worked (as in, my code ran without errors and current tests still pass), but I imagine it would break backwards compatibility for previously trained MuData-based models. Moreover, I still needed a way to tell setup_anndata ahead of time which specific variable names to consider, which seemed a bit ugly to me. Do you have ideas for a cleaner way to associate features with modalities?
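To make the var_names problem above concrete, here is a minimal pure-Python sketch. It does not use scvi-tools or MuData; the feature names and the `missing_features` helper are hypothetical and exist only for illustration. It compares what looks "missing" from a protein-only query when the reference names are stored as one flat list versus as a dictionary keyed by modality.

```python
# Hypothetical feature names; not taken from any real dataset or from scvi-tools.
reference_var_names_flat = ["GeneA", "GeneB", "GeneC", "CD3", "CD4"]
reference_var_names_by_mod = {
    "rna": ["GeneA", "GeneB", "GeneC"],
    "protein": ["CD3", "CD4"],
}


def missing_features(query_names, reference_names):
    """Return the reference features absent from the query, in reference order."""
    present = set(query_names)
    return [name for name in reference_names if name not in present]


# A protein-only query that already matches the reference protein panel.
query_protein_names = ["CD3", "CD4"]

# Flat list: the RNA features look "missing" and would be zero-padded.
missing_flat = missing_features(query_protein_names, reference_var_names_flat)

# Per-modality dictionary: nothing is missing for the protein modality.
missing_by_mod = missing_features(
    query_protein_names, reference_var_names_by_mod["protein"]
)

print(missing_flat)
print(missing_by_mod)
```

With the flat list, all three RNA features are reported as missing from the protein query; with the per-modality dictionary, none are.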
@martinkim0 with the closing of #2769, I've gone ahead and taken another stab at this PR.
The biggest change now compared to main is that I factored out most of the old code in prepare_query_anndata into a new function _pad_and_sort_query_anndata that's reused across prepare_query_anndata and a new prepare_query_mudata method.
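The shape of that refactor can be sketched roughly as follows. This is an illustrative pure-Python stand-in, not the code in this PR: `pad_and_sort`, `prepare_query`, and `prepare_query_multimodal` are hypothetical names, the data are plain lists rather than AnnData/MuData objects, and dropping query-only features is an assumption of the sketch.

```python
def pad_and_sort(rows, query_names, reference_names):
    """Reorder columns to match reference_names, zero-filling absent features.

    Assumption of this sketch: query features that are not in the
    reference are dropped.
    """
    col = {name: j for j, name in enumerate(query_names)}
    return [
        [row[col[name]] if name in col else 0 for name in reference_names]
        for row in rows
    ]


def prepare_query(query, reference_names):
    """Single-modality entry point; query is a (rows, names) pair."""
    rows, names = query
    return pad_and_sort(rows, names, reference_names), list(reference_names)


def prepare_query_multimodal(query_by_mod, reference_names_by_mod):
    """Multi-modality entry point; reuses the same helper per modality."""
    return {
        mod: prepare_query(query_by_mod[mod], ref_names)
        for mod, ref_names in reference_names_by_mod.items()
    }


# Hypothetical data: one cell, with features in a different order than the reference.
query = {
    "rna": ([[5, 7]], ["GeneC", "GeneA"]),
    "protein": ([[1, 2]], ["CD4", "CD3"]),
}
reference = {
    "rna": ["GeneA", "GeneB", "GeneC"],
    "protein": ["CD3", "CD4"],
}

prepared = prepare_query_multimodal(query, reference)
rna_rows, rna_names = prepared["rna"]
protein_rows, protein_names = prepared["protein"]
```

Because the padding logic lives in one helper, each modality is padded only against its own reference features, and the single-modality and multi-modality entry points cannot drift apart.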
Let me know what you think/if there's anything else you'd like in this PR in terms of testing/documentation/etc.
Codecov Report
Attention: Patch coverage is 88.52459% with 7 lines in your changes missing coverage. Please review.
Project coverage is 85.21%. Comparing base (7377065) to head (295898a).
Additional details and impacted files
@@ Coverage Diff @@
## main #2578 +/- ##
=======================================
Coverage 85.21% 85.21%
=======================================
Files 166 166
Lines 14202 14236 +34
=======================================
+ Hits 12102 12131 +29
- Misses 2100 2105 +5
| Files | Coverage Δ | |
|---|---|---|
| src/scvi/model/base/_archesmixin.py | 92.16% <88.52%> (-1.78%) | :arrow_down: |
Hi @martinkim0. Just wanted to check in to see if you had any feedback for the latest version of this PR. Thanks!
Thanks a lot! This looks great - I think refactoring out _pad_and_sort_query_anndata makes a lot of sense. Just left some minor comments regarding typing and getting var names.
One last thing: could you add a test for prepare_query_mudata as well as a release note in CHANGELOG.md (this can go under the v1.2 release).
Added the requested typing + incorporated your suggestion on ._get_var_names. For testing I added a new test test_scarches_mudata_prep_layer in test_totalvi.py analogous to test_scarches_data_prep_layer in test_scvi.py.