scvi-tools
[WIP] lvm-DE code
Fixes #1068
Here is the code we discussed for lvm-DE, an importance-sampling-based procedure for calibrated DE gene predictions.
A parallel PR will be opened asap for the associated notebook introducing the module.
Depending on your call on the place of this tool in scvi-tools, it could require some additional work (see questions).
Presentation of the PR
New files
- external/wscvi: In line with existing external packages, the wscvi folder contains the code to run an SCVI version that easily yields variational distributions, prior densities, and likelihoods. Access to these quantities matters for two reasons. First, lvm-DE relies on importance sampling, and hence directly requires them. Second, it allows training scVI models with IWELBO, which arguably improves the stability of lvm-DE.
- model/base/_demixin.py contains the lvm-DE core. I documented it as much as I could, but it still needs additional documentation. In particular, the lvm_de method applies the same mechanisms as differential_expression in _rnamixin.py. It also uses DifferentialComputation under the hood.
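To make the role of these three quantities concrete, here is a minimal sketch of self-normalized importance sampling on a toy one-dimensional Gaussian model. It is not the lvm-DE implementation: the model (z ~ N(0, 1), x | z ~ N(z, 1)), the proposal parameters, and the names log_normal and snis_posterior_mean are assumptions made for illustration only.

```python
import numpy as np


def log_normal(x, mean, std):
    """Log density of N(mean, std^2) evaluated at x."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2


def snis_posterior_mean(x, q_mean, q_std, n_samples, rng):
    """Self-normalized importance sampling estimate of E[z | x] (toy model)."""
    z = rng.normal(q_mean, q_std, size=n_samples)  # z ~ q(z | x), the proposal
    log_w = (
        log_normal(x, z, 1.0)            # likelihood log p(x | z)
        + log_normal(z, 0.0, 1.0)        # prior log p(z)
        - log_normal(z, q_mean, q_std)   # variational density log q(z | x)
    )
    log_w -= log_w.max()                 # stabilize before exponentiating
    w = np.exp(log_w)
    w /= w.sum()                         # self-normalization
    return float(np.sum(w * z))


rng = np.random.default_rng(0)
x_obs = 1.0
# In this toy model the exact posterior is N(x / 2, 1 / 2), so E[z | x] = 0.5.
estimate = snis_posterior_mean(x_obs, q_mean=0.3, q_std=1.0, n_samples=50_000, rng=rng)
```

The weights need the likelihood, the prior density, and the variational density all at once, which is why a model that exposes them directly is convenient here.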
Modified files
The main modification I made was to make DifferentialComputation usable with lvm-DE.
To get a given number of gene normalized expression samples for a selection of indices, DifferentialComputation
randomly samples cell indices, using
idx = np.random.choice(np.arange(self.adata.shape[0])[selection], n_samples)
Next, we get one sample for each index using the variational distribution directly. This reasoning does not work well for lvm-DE.
Consequently, I modified DifferentialComputation such that it lets model_fn take care of the sampling instead, through n_samples_overall.
For instance, when using the current differential tools, model_fn now takes an n_samples_overall parameter and calls np.random.choice the same way as before.
While this is a small change, it required me to change each such function in the codebase (in _peakvi.py, _scvi.py, and _totalvi.py).
This also introduced some code duplication, which I am not fond of and will try to remove.
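As a rough illustration of the change described above, the sketch below shows a model_fn-like function that owns the sampling step and receives n_samples_overall. The function toy_model_fn, its signature, and the noise used as a stand-in for a draw from the variational distribution are hypothetical and do not reproduce the actual scvi-tools code.

```python
import numpy as np


def toy_model_fn(expression, selection, n_samples_overall, rng):
    """Hypothetical model_fn: picks cells itself, then returns one draw per pick."""
    # Same idea as the original np.random.choice call, now inside model_fn.
    candidates = np.arange(expression.shape[0])[selection]
    idx = rng.choice(candidates, n_samples_overall)  # sampled with replacement
    # Stand-in for one sample of normalized expression per selected cell.
    noise = rng.normal(scale=0.01, size=(n_samples_overall, expression.shape[1]))
    return expression[idx] + noise, idx


rng = np.random.default_rng(0)
expression = rng.gamma(shape=2.0, scale=1.0, size=(100, 5))  # 100 cells, 5 genes
selection = np.zeros(100, dtype=bool)
selection[:30] = True                                        # cells of interest
samples, idx = toy_model_fn(expression, selection, n_samples_overall=200, rng=rng)
```

With this layout the caller only states how many samples it wants overall, and a different model_fn (for example one that reweights samples) can choose cells in its own way.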
I also allowed the get_marginal_ll methods to return observation-specific marginal densities as an option, as lvm-DE requires these densities.
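For reference, a minimal sketch of what observation-specific marginal densities can look like when computed from importance weights. The function marginal_ll, the per_observation flag, and the choice to sum over observations in the aggregated case are assumptions for illustration; they are not the signature or the aggregation used by get_marginal_ll.

```python
import numpy as np


def marginal_ll(log_weights, per_observation=False):
    """Importance-sampling estimate of log p(x_i) from log-weights.

    log_weights has shape (n_observations, n_particles), where each entry is
    log p(x_i, z_k) - log q(z_k | x_i).
    """
    n_particles = log_weights.shape[1]
    row_max = log_weights.max(axis=1, keepdims=True)
    # Numerically stable log-mean-exp over particles, one value per observation.
    log_px = (
        row_max[:, 0]
        + np.log(np.exp(log_weights - row_max).sum(axis=1))
        - np.log(n_particles)
    )
    return log_px if per_observation else float(log_px.sum())


# Constant weights of 2 give log p(x_i) = log 2 for every observation.
log_w = np.full((4, 8), np.log(2.0))
per_obs = marginal_ll(log_w, per_observation=True)
total = marginal_ll(log_w)
```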
Questions
Here are some uncertainties I have on what we should do with this code.
- Should we somehow "merge" WSCVI and SCVI? In other words, should we provide an n_particles option to SCVI to switch from a mode that computes KL divergences to one that relies on log-ratios to train the model? An easy first step in this direction is to return torch.Distribution in inference and generative. From there, we have two options. First, we could let the loss function decide how to compute the loss (using kl_divergence when possible and densities otherwise). In that scenario, we would have a somewhat unified VAE/WVAE model. This could however make the module heavier, and I don't know if this is a good idea. The second option, more in line with what WVAE does, would be to keep the log-ratio-friendly model on the side. In any case, I think that a WVAE-inspired structure can help developers in ways that go beyond DE. Indeed, it can easily be changed to consider alternative inference networks (e.g., using flows) and priors for z (hyperspherical, vamp priors, etc.), or maybe to more easily consider hierarchical models.
- Where should we keep WSCVI? I think the way to go is to keep it as an external package, but what do you think?
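To illustrate the first question, here is a small sketch of a loss helper that uses a closed-form KL when no particle count is given and falls back to a Monte Carlo log-ratio otherwise. It uses NumPy and a one-dimensional Gaussian as a stand-in for torch.Distribution objects; kl_term and its n_particles argument are hypothetical names. It only covers the KL term, whereas an IWELBO objective would combine log-ratios inside a log-mean-exp.

```python
import numpy as np


def log_normal(x, mean, std):
    """Log density of N(mean, std^2) evaluated at x."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2


def kl_term(mu, sigma, n_particles=None, rng=None):
    """KL(q || p) for q = N(mu, sigma^2) and p = N(0, 1).

    Without n_particles, use the closed form (the kl_divergence route).
    With n_particles, average log q(z) - log p(z) over samples z ~ q
    (the log-ratio route, which only needs densities).
    """
    if n_particles is None:
        return float(-np.log(sigma) + 0.5 * (sigma**2 + mu**2 - 1.0))
    z = rng.normal(mu, sigma, size=n_particles)
    return float(np.mean(log_normal(z, mu, sigma) - log_normal(z, 0.0, 1.0)))


rng = np.random.default_rng(0)
exact = kl_term(0.5, 0.8)
approx = kl_term(0.5, 0.8, n_particles=200_000, rng=rng)
```

The log-ratio route works for any prior or inference network whose density can be evaluated, which is the flexibility the question is weighing against the extra weight in the module.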
Codecov Report
Base: 90.41% // Head: 90.99% // Increases project coverage by +0.57% :tada:
Coverage data is based on head (7b64282) compared to base (47bf4a5). Patch coverage: 92.44% of modified lines in pull request are covered.
:exclamation: Current head 7b64282 differs from pull request most recent head 05eced2. Consider uploading reports for the commit 05eced2 to get more accurate results
Additional details and impacted files
@@ Coverage Diff @@
## main #1067 +/- ##
==========================================
+ Coverage 90.41% 90.99% +0.57%
==========================================
Files 141 115 -26
Lines 11099 9014 -2085
==========================================
- Hits 10035 8202 -1833
+ Misses 1064 812 -252
| Impacted Files | Coverage Δ | |
|---|---|---|
| scvi/external/wscvi/_module.py | 84.00% <84.00%> (ø) | |
| scvi/model/base/_demixin.py | 93.43% <93.43%> (ø) | |
| scvi/external/__init__.py | 100.00% <100.00%> (ø) | |
| scvi/external/wscvi/__init__.py | 100.00% <100.00%> (ø) | |
| scvi/external/wscvi/_model.py | 100.00% <100.00%> (ø) | |
| scvi/model/_linear_scvi.py | 100.00% <100.00%> (ø) | |
| scvi/model/_scanvi.py | 93.60% <100.00%> (+0.43%) | :arrow_up: |
| scvi/model/_scvi.py | 100.00% <100.00%> (+4.54%) | :arrow_up: |
| scvi/model/base/__init__.py | 100.00% <100.00%> (ø) | |
| scvi/model/base/_differential.py | 86.82% <100.00%> (+0.69%) | :arrow_up: |
| ... and 129 more | | |
Hi @PierreBoyeau, is this PR safe to close due to #1872?