
[WIP] lvm-DE code

Open PierreBoyeau opened this issue 4 years ago • 1 comment

Fixes #1068. Here is the code we discussed for lvm-DE, an importance-sampling-based procedure for calibrated DE gene predictions. A parallel PR will be opened ASAP for the associated notebook introducing the module. Depending on your call on where this tool should live in scvi-tools, it may require some additional work (see questions).

Presentation of the PR

New files

  • external/wscvi: In line with existing external packages, the wscvi folder contains the code to run an SCVI variant that readily exposes variational distributions, prior densities, and likelihoods. Access to these quantities matters for two reasons. First, lvm-DE relies on importance sampling and hence requires them directly. Second, it allows training scVI models with the IWELBO, which arguably improves the stability of lvm-DE (see the sketch after this list).

  • model/base/_demixin.py contains the lvm-DE core. I documented it as much as possible, but it still needs additional documentation. In particular, the lvm_de method follows the same mechanism as differential_expression in _rnamixin.py. It also uses DifferentialComputation under the hood.
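
As referenced in the wscvi item above, here is a minimal sketch of the IWELBO-style log-ratio objective that such a module makes easy to compute. It assumes the variational posterior `qz`, the prior `pz`, and the conditional likelihood `px` are available as torch distributions; these names are illustrative, not the actual wscvi outputs.

    import math
    import torch

    def iwelbo(qz, pz, px, x, z):
        # z: (n_particles, n_cells, n_latent), x: (n_cells, n_genes)
        # importance weights are density ratios evaluated at the sampled particles
        log_ratio = (
            px.log_prob(x).sum(-1) + pz.log_prob(z).sum(-1) - qz.log_prob(z).sum(-1)
        )
        # log-mean-exp over particles yields the IWELBO for each cell
        return torch.logsumexp(log_ratio, dim=0) - math.log(log_ratio.shape[0])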

Modified files

The main modification I made was to make DifferentialComputation usable with lvm-DE. To get a given number of normalized gene expression samples for a selection of indices, DifferentialComputation randomly samples cell indices, using

    idx = np.random.choice(np.arange(self.adata.shape[0])[selection], n_samples)

Next, we get one sample for each index using the variational distribution directly. This reasoning does not work well for lvm-DE.

Consequently, I modified DifferentialComputation so that it lets model_fn take care of the sampling instead, through n_samples_overall. For instance, when using the current differential tools, model_fn now takes an n_samples_overall parameter and calls np.random.choice the same way as before. While this is a small change, it required me to change every such function in the codebase (in _peakvi.py, _scvi.py, and _totalvi.py). It also introduced some code duplication, which I am not fond of but will try to remove.
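
To illustrate the direction of the change (the function and argument names below are a simplified sketch, not the exact code in the PR), the subsampling now happens inside the model function itself:

    import numpy as np

    def model_fn(expression, indices, n_samples_overall=None):
        # Hypothetical sketch: when n_samples_overall is provided, the model
        # function subsamples cell indices itself, reproducing the previous
        # np.random.choice behaviour of DifferentialComputation.
        if n_samples_overall is not None:
            indices = np.random.choice(indices, n_samples_overall)
        # a real model function would run the generative model here; this sketch
        # simply returns the selected rows of a precomputed expression matrix
        return expression[indices]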

I also added an option for the get_marginal_ll methods to return observation-specific marginal densities, as lvm-DE requires these.
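
For instance (the flag name below is illustrative, not necessarily the one used in the PR), the change amounts to optionally skipping the final reduction over observations:

    import torch

    def marginal_ll_summary(per_cell_log_densities, return_mean=True):
        # Hypothetical sketch: either reduce to a single scalar as before, or
        # return the observation-specific marginal log-densities lvm-DE needs.
        if return_mean:
            return torch.mean(per_cell_log_densities).item()
        return per_cell_log_densities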

Questions

Here are some uncertainties I have about what we should do with this code.

  1. Should we somehow "merge" WSCVI and SCVI? In other words, should we provide an n_particles option to SCVI to switch from a mode computing KL divergences to one relying on log-ratios to train the model? An easy first step in this direction is to return torch.Distribution objects in inference and generative. From there, we have two options. First, we could let the loss function decide how to compute the loss (using kl_divergence when possible and densities otherwise; see the sketch after this list). In that scenario, we would have a somewhat unified VAE/WVAE model. This could, however, make the module heavier, and I don't know if it is a good idea. The second option, more in line with what WVAE does, would be to keep the log-ratio-friendly model on the side. In any case, I think that a WVAE-inspired structure can help developers in ways that go beyond DE: it can easily be changed to consider alternative inference networks (e.g., using flows) and priors for z (hyperspherical, VampPrior, etc.), or perhaps more easily accommodate hierarchical models.

  2. Where should we keep WSCVI? I think the way to go is to keep it as an external package, but what do you think?
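
Regarding the first option above, here is a rough sketch of how a loss term could dispatch between the analytic KL and a log-ratio estimate. This is purely illustrative, not the module's actual loss code.

    import torch
    from torch.distributions import kl_divergence

    def latent_kl_term(qz, pz, z):
        # Illustrative only: use the closed-form KL when torch has a rule
        # registered for this pair of distributions, otherwise fall back to a
        # Monte Carlo estimate from log-density ratios at the sampled particles z.
        try:
            return kl_divergence(qz, pz)
        except NotImplementedError:
            return (qz.log_prob(z) - pz.log_prob(z)).mean(0)

With distributions returned directly from inference and generative, both the standard ELBO and a log-ratio-based objective could share an entry point like this.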

PierreBoyeau avatar May 19 '21 12:05 PierreBoyeau

Codecov Report

Base: 90.41% // Head: 90.99% // Increases project coverage by +0.57% :tada:

Coverage data is based on head (7b64282) compared to base (47bf4a5). Patch coverage: 92.44% of modified lines in pull request are covered.

:exclamation: Current head 7b64282 differs from pull request most recent head 05eced2. Consider uploading reports for the commit 05eced2 to get more accurate results.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1067      +/-   ##
==========================================
+ Coverage   90.41%   90.99%   +0.57%     
==========================================
  Files         141      115      -26     
  Lines       11099     9014    -2085     
==========================================
- Hits        10035     8202    -1833     
+ Misses       1064      812     -252     
| Impacted Files | Coverage Δ |
|----------------|------------|
| scvi/external/wscvi/_module.py | 84.00% <84.00%> (ø) |
| scvi/model/base/_demixin.py | 93.43% <93.43%> (ø) |
| scvi/external/__init__.py | 100.00% <100.00%> (ø) |
| scvi/external/wscvi/__init__.py | 100.00% <100.00%> (ø) |
| scvi/external/wscvi/_model.py | 100.00% <100.00%> (ø) |
| scvi/model/_linear_scvi.py | 100.00% <100.00%> (ø) |
| scvi/model/_scanvi.py | 93.60% <100.00%> (+0.43%) :arrow_up: |
| scvi/model/_scvi.py | 100.00% <100.00%> (+4.54%) :arrow_up: |
| scvi/model/base/__init__.py | 100.00% <100.00%> (ø) |
| scvi/model/base/_differential.py | 86.82% <100.00%> (+0.69%) :arrow_up: |
| ... and 129 more | |


:umbrella: View full report at Codecov.

codecov[bot] avatar May 19 '21 12:05 codecov[bot]

Hi @PierreBoyeau, is this PR safe to close due to #1872?

martinkim0 avatar Feb 23 '23 04:02 martinkim0