
Adding cycle consistency loss and VampPrior to scVI

Status: Open. Hrovatin opened this issue 7 months ago • 15 comments

We were considering adding the cycle-consistency loss and VampPrior described in the bioRxiv preprint "Integrating single-cell RNA-seq datasets with substantial batch effects" to the scvi-tools package, preferably directly into scVI.

The current implementation is based on a conditional VAE (cVAE) that works with normalized, log-transformed data: https://github.com/theislab/cross_system_integration/blob/main/cross_system_integration/model/_xxjointmodel.py
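For contrast with scVI, which models raw counts with a count likelihood, the per-cell preprocessing implied by "normalized, log-transformed" input can be sketched as below. The target total of 1e4 is an assumption (a common choice in Scanpy workflows), not necessarily what the cross_system_integration repository uses, and normalize_log1p is a hypothetical helper name.

```python
import math
from typing import List, Sequence


def normalize_log1p(counts: Sequence[float], target_sum: float = 1e4) -> List[float]:
    """Scale one cell's counts so they sum to target_sum, then apply log(1 + x).

    target_sum=1e4 is an assumed convention, not taken from the original code.
    """
    total = sum(counts)
    if total <= 0:
        # An empty cell stays all-zero instead of dividing by zero.
        return [0.0 for _ in counts]
    return [math.log1p(c * target_sum / total) for c in counts]


# Usage: three genes in one cell.
print([round(v, 3) for v in normalize_log1p([3, 0, 1])])
```

Porting the model into scVI would mean either keeping this input representation or switching the reconstruction term to scVI's count likelihoods.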

Adding VampPrior would require changes in the model class, to initialise the VampPrior pseudoinputs (https://github.com/theislab/cross_system_integration/blob/322549224cddcd5d375f8b49cb9f0e0e77c2be1f/cross_system_integration/model/_xxjointmodel.py#L87C10-L87C10), and in the module, to replace the standard normal prior (https://github.com/theislab/cross_system_integration/blob/322549224cddcd5d375f8b49cb9f0e0e77c2be1f/cross_system_integration/module/_xxjointmodule.py#L165C11-L165C11).
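To make the two touch points concrete: a VampPrior replaces the N(0, I) prior with a uniform mixture of the encoder's own posteriors evaluated at K pseudoinputs u_k, i.e. p(z) = (1/K) * sum_k q(z | u_k). The KL term then has no closed form and is usually estimated from a sample z ~ q(z | x) as log q(z | x) - log p(z). The sketch below is pure Python for illustration only; toy_encoder is a hypothetical stand-in for the scVI encoder, and initialising pseudoinputs by copying randomly chosen training cells is an assumption, not a description of the linked code. In an actual module the pseudoinputs would typically be trainable parameters passed through the encoder at every step.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
# Hypothetical encoder signature: input -> (mean, variance) of a diagonal Gaussian.
Encoder = Callable[[Sequence[float]], Tuple[Vector, Vector]]


def diag_gaussian_log_prob(z: Sequence[float], mean: Sequence[float], var: Sequence[float]) -> float:
    """log N(z; mean, diag(var)), summed over latent dimensions."""
    return sum(
        -0.5 * (math.log(2 * math.pi * v) + (zi - m) ** 2 / v)
        for zi, m, v in zip(z, mean, var)
    )


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(values)))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def init_pseudoinputs(cells: Sequence[Sequence[float]], n_components: int, seed: int = 0) -> List[Vector]:
    """Assumed initialisation: copy n_components randomly chosen cells (no replacement)."""
    rng = random.Random(seed)
    return [list(c) for c in rng.sample(list(cells), n_components)]


def vamp_prior_log_prob(z: Sequence[float], pseudoinputs: Sequence[Sequence[float]], encoder: Encoder) -> float:
    """log p(z) where p(z) = (1/K) * sum_k q(z | u_k)."""
    component_log_probs = [diag_gaussian_log_prob(z, *encoder(u)) for u in pseudoinputs]
    return logsumexp(component_log_probs) - math.log(len(pseudoinputs))


def kl_single_sample(
    z: Sequence[float],
    posterior_mean: Sequence[float],
    posterior_var: Sequence[float],
    pseudoinputs: Sequence[Sequence[float]],
    encoder: Encoder,
) -> float:
    """One-sample Monte Carlo estimate of KL(q(z|x) || p(z)), with z drawn from q(z|x)."""
    log_q = diag_gaussian_log_prob(z, posterior_mean, posterior_var)
    return log_q - vamp_prior_log_prob(z, pseudoinputs, encoder)


def toy_encoder(u: Sequence[float]) -> Tuple[Vector, Vector]:
    """Hypothetical encoder for illustration: identity mean, unit variance."""
    return list(u), [1.0] * len(u)


# Usage: two pseudoinputs drawn from three toy cells.
cells = [[0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
pseudoinputs = init_pseudoinputs(cells, n_components=2)
print(vamp_prior_log_prob([0.0, 0.0], pseudoinputs, toy_encoder))
```

With a single pseudoinput whose encoding is N(0, I), this reduces to the standard normal prior, which is a convenient sanity check when swapping the prior in.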

Adding cycle consistency would require changes in the AnnData setup to specify which batch covariate should be used as the "system", i.e. the covariate that is specifically corrected for with Lcyc. We specify systems and batches within systems (the latter are not corrected by Lcyc); batch_key and categorical_covariate_keys/continuous_covariate_keys, respectively, could be used for these. In addition, the forward pass would need to change and an additional loss term would need to be added; see this file: https://github.com/theislab/cross_system_integration/blob/multiple_sys_model/cross_system_integration/module/_xxjointmodule.py#L235. This branch has a slightly different implementation than the one in the paper, but I would suggest adding it, as it is more general.
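The cycle described above can be sketched as: encode a cell with its own system label to get z, decode z with a different system label, re-encode that translated cell with the new label to get z_cycle, and penalise the distance between z and z_cycle. Only the system label is switched. Several details here are assumptions rather than the linked implementation: the squared-error distance on latent means, drawing the target system uniformly from the other systems, and passing the cell's within-system batch covariate through unchanged. The encoder and decoder below are hypothetical toy functions.

```python
import random
from typing import Callable, List, Sequence

Vector = List[float]
# Hypothetical signatures: both networks see the system label and a within-system
# batch covariate; only the system label is switched during the cycle.
Encoder = Callable[[Sequence[float], int, int], Vector]  # (x, system, batch) -> latent mean
Decoder = Callable[[Sequence[float], int, int], Vector]  # (z, system, batch) -> reconstruction


def pick_other_system(system: int, n_systems: int, rng: random.Random) -> int:
    """Draw a target system uniformly from all systems except the cell's own (needs n_systems >= 2)."""
    choice = rng.randrange(n_systems - 1)
    return choice if choice < system else choice + 1


def cycle_consistency_loss(
    x: Sequence[float], system: int, batch: int, n_systems: int,
    encoder: Encoder, decoder: Decoder, rng: random.Random,
) -> float:
    z = encoder(x, system, batch)                   # latent of the original cell
    target = pick_other_system(system, n_systems, rng)
    x_translated = decoder(z, target, batch)        # the cell "as if" measured in another system
    z_cycle = encoder(x_translated, target, batch)  # re-encode with the switched system label
    return sum((a - b) ** 2 for a, b in zip(z, z_cycle)) / len(z)


# Hypothetical toy model: each system adds a fixed shift to expression.
SYSTEM_OFFSETS = [[0.0, 0.0], [2.0, -1.0]]


def removing_encoder(x, system, batch):
    """Removes the system shift, so the latent carries no system information."""
    return [xi - o for xi, o in zip(x, SYSTEM_OFFSETS[system])]


def leaky_encoder(x, system, batch):
    """Ignores the system label, so the system shift leaks into the latent."""
    return list(x)


def shifting_decoder(z, system, batch):
    return [zi + o for zi, o in zip(z, SYSTEM_OFFSETS[system])]


rng = random.Random(0)
print(cycle_consistency_loss([1.0, 3.0], 0, 0, 2, removing_encoder, shifting_decoder, rng))  # 0.0
print(cycle_consistency_loss([1.0, 3.0], 0, 0, 2, leaky_encoder, shifting_decoder, rng))     # 2.5
```

The toy case shows the intended behaviour: the loss vanishes when the latent space is free of system effects and grows when system information leaks into it, while within-system batches are left to the usual covariate conditioning.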

Hrovatin, Jan 09 '24 07:01