pyro
[WIP] Autonormal encoder
@martinjankowiak @fritzo @eb8680 Following our conversation at https://github.com/YosefLab/scvi-tools/pull/930#pullrequestreview-661114030, I am creating this PR to discuss adding an Autonormal encoder class.
This class needs users to specify the encoder network class, a data transformation, and an amortised_plate_sites dictionary, which states which variables are amortised, which model args/kwargs need to be passed to the encoder, and which plate the variables belong to.
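As a rough illustration of the kind of specification described above, a dictionary along these lines could be used. Only the "name" key is referenced elsewhere in this thread (`self.amortised_plate_sites["name"]`); the "input" and "sites" keys and the `select_encoder_inputs` helper are hypothetical names made up for this sketch and are not the PR's actual API.

```python
# Hypothetical layout; only the "name" key is confirmed elsewhere in this thread.
amortised_plate_sites = {
    "name": "obs_plate",          # name of the minibatch plate in the model
    "input": [0, "batch_index"],  # assumed key: args positions / kwargs names fed to the encoder
    "sites": {"z": 10, "l": 1},   # assumed key: amortised site -> size of its non-plate dimension
}


def select_encoder_inputs(spec, args, kwargs):
    """Pick the model args/kwargs that the spec routes to the encoder."""
    selected = []
    for key in spec["input"]:
        # Integers index positional arguments, strings index keyword arguments.
        selected.append(args[key] if isinstance(key, int) else kwargs[key])
    return selected


inputs = select_encoder_inputs(
    amortised_plate_sites,
    args=("x_data",),
    kwargs={"batch_index": "batch", "unused": None},
)
```

Keeping the routing in plain data like this would let the same guide class serve models with different call signatures, at the cost of the user having to keep the dictionary in sync with the model.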
One of the main assumptions at the moment is that encoded variables are 2D tensors. I guess the shape could be inferred automatically, but I have not thought that through and do not have applications where variables are more than 2D.
Looks like I accidentally included changes to AutoGuideList from PR #2837 (still learning how to use git correctly).
> Looks like accidentally included changes from ...
No worries, #2837 should merge soon. We often add "Blocked by #xxx" in the PR description to denote merge order dependencies.
What is missing at the moment is a simple encoder NN class. @fritzo @martinjankowiak is there anything already defined in pyro or a good example?
> What is missing at the moment is a simple encoder NN class. ... is there anything already defined in pyro or a good example?
Existing code includes:
Also feel free to add something to a new file in pyro/nn/
I started thinking about tests and realised that testing this class also needs a model with local variables and some data. @fritzo do you have any good example in mind (ideally one that is already implemented in tests)?
One alternative would be to use the scVI Pyro test regression model and write simple training and posterior sampling code to test this class. Actually, for posterior sampling and computing medians/quantiles, concatenating encoded local variables along the plate dimension is non-trivial and is the subject of this scVI PR (PyroSampleMixin class): https://github.com/YosefLab/scvi-tools/pull/1059
Could be good if the AutoNormalEncoder class provided a method to merge quantiles, medians and posterior samples along the plate dimension. WDYT?
First, guess the plate dimension:

```python
# requires: from pyro import poutine
def _guess_obs_plate_sites(self, args, kwargs):
    """
    Automatically guess which model sites belong to observation/minibatch plate.
    This function requires minibatch plate name specified in `self.amortised_plate_sites["name"]`.

    Parameters
    ----------
    args
        Arguments to the model.
    kwargs
        Keyword arguments to the model.

    Returns
    -------
    Dictionary with keys corresponding to site names and values to plate dimension.
    """
    plate_name = self.amortised_plate_sites["name"]

    # find plate dimension
    trace = poutine.trace(self.model).get_trace(*args, **kwargs)
    obs_plate = {
        # take the dim of the frame whose name matches, not simply the first frame,
        # so that sites nested in several plates are handled correctly
        name: next(f.dim for f in site["cond_indep_stack"] if f.name == plate_name)
        for name, site in trace.nodes.items()
        if site["type"] == "sample"
        if any(f.name == plate_name for f in site["cond_indep_stack"])
    }

    return obs_plate
```
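To show the lookup without needing a model, here is a minimal sketch that runs the same kind of search on a hand-built stand-in for `trace.nodes`. The `Frame` namedtuple, the site names and the standalone `guess_obs_plate_sites` function are assumptions for illustration only; real Pyro plate frames expose `name` and `dim` in the same way. The dimension is read from the frame whose name matches the plate, which for sites inside a single plate is also the first frame.

```python
from collections import namedtuple

# Stand-in for Pyro's plate frames, which expose `name` and `dim`.
Frame = namedtuple("Frame", ["name", "dim"])

# Hand-built imitation of `trace.nodes` for a model with one global site
# and two sites inside a minibatch plate called "obs_plate".
nodes = {
    "global_scale": {"type": "sample", "cond_indep_stack": ()},
    "z": {"type": "sample", "cond_indep_stack": (Frame("obs_plate", -2),)},
    "x": {
        "type": "sample",
        "cond_indep_stack": (Frame("feature_plate", -1), Frame("obs_plate", -2)),
    },
    "_RETURN": {"type": "return"},
}


def guess_obs_plate_sites(nodes, plate_name):
    """Map each sample site inside `plate_name` to that plate's dimension."""
    obs_plate = {}
    for name, site in nodes.items():
        if site["type"] != "sample":
            continue  # skip bookkeeping nodes such as the return value
        for frame in site["cond_indep_stack"]:
            if frame.name == plate_name:
                obs_plate[name] = frame.dim
    return obs_plate


obs_plate_sites = guess_obs_plate_sites(nodes, "obs_plate")
print(obs_plate_sites)  # {'z': -2, 'x': -2}
```

The global site is left out because its plate stack is empty, and the doubly nested site "x" still reports the dimension of "obs_plate" rather than that of "feature_plate".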
Then concatenate samples along that dimension:

```python
import numpy as np

for i, (args, kwargs) in enumerate(dataloader):
    if i == 0:
        # first minibatch: initialise the results and find the plate dimension
        samples = guide.quantiles(0.5, *args, **kwargs)
        obs_plate_sites = guide._guess_obs_plate_sites(args, kwargs)
        obs_plate_dim = list(obs_plate_sites.values())[0]
    else:
        # later minibatches: append to what has been collected so far
        samples_ = guide.quantiles(0.5, *args, **kwargs)
        samples = {
            k: np.array(
                [
                    np.concatenate(
                        [samples[k][j], samples_[k][j]],
                        axis=obs_plate_dim,
                    )
                    for j in range(len(samples[k]))  # for each sample (along dimension 0)
                ]
            )
            for k in samples.keys()  # for each variable
        }
```
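A possible simplification, sketched here with NumPy arrays standing in for the guide's output: because Pyro plate dimensions are negative (counted from the right), concatenating the full arrays along `obs_plate_dim` should give the same result as looping over the leading sample dimension and re-stacking. The `concat_along_plate` helper, the site name "z" and the array shapes are assumptions for this example.

```python
import numpy as np


def concat_along_plate(batches, obs_plate_dim):
    """Concatenate per-minibatch dicts of arrays along the plate dimension."""
    return {
        k: np.concatenate([b[k] for b in batches], axis=obs_plate_dim)
        for k in batches[0]
    }


rng = np.random.default_rng(0)
# Two minibatches with 4 and 3 observations; arrays are (n_samples, n_obs, n_features).
batch_a = {"z": rng.normal(size=(5, 4, 2))}
batch_b = {"z": rng.normal(size=(5, 3, 2))}

merged = concat_along_plate([batch_a, batch_b], obs_plate_dim=-2)

# Reference: loop over the leading sample dimension and re-stack.
looped = np.array(
    [
        np.concatenate([batch_a["z"][j], batch_b["z"][j]], axis=-2)
        for j in range(5)
    ]
)
```

Collecting all minibatches first and concatenating once also avoids re-copying the accumulated arrays on every iteration. Note that this sketch, like the loop it mirrors, concatenates every key it is given, so it should only be handed sites that actually live in the minibatch plate.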
I extended this class further to enable more complex architectures (see below) and a different number of hidden nodes for each model site.
- Single encoder NN for all pyro model sites (encoder_mode='single') where means and scales linearly depend on the last NN layer: A -> site1, site2, site3 ... siteN;
- Separate NN for each pyro model site (encoder_mode='multiple') where means and scales linearly depend on the last NN layer: B -> site1; B -> site2; ... B -> siteN;
- Single encoder NN followed by another layer of separate NN for each pyro model site (encoder_mode='single-multiple') where means and scales linearly depend on the last NN layer. Aka branching network: A -> B; B -> site1; B -> site2; ... B -> siteN;
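The three modes listed above can be sketched in PyTorch roughly as follows. This is an illustration under assumptions, not the PR's actual implementation: the `SketchEncoder` class, the `mlp` helper, the single hidden layer, the softplus used to keep scales positive and the site names are all invented for the example.

```python
import torch
from torch import nn


def mlp(n_in, n_hidden):
    """One hidden layer with a nonlinearity; a placeholder for a real encoder."""
    return nn.Sequential(nn.Linear(n_in, n_hidden), nn.ReLU())


class SketchEncoder(nn.Module):
    """Hypothetical encoder returning (loc, scale) for each amortised site."""

    def __init__(self, n_in, site_dims, n_hidden=16, encoder_mode="single"):
        super().__init__()
        if encoder_mode not in ("single", "multiple", "single-multiple"):
            raise ValueError(f"unknown encoder_mode: {encoder_mode}")
        # Shared network A (not used in 'multiple' mode).
        self.shared = mlp(n_in, n_hidden) if encoder_mode != "multiple" else None
        # Per-site networks B (not used in 'single' mode).
        branch_in = n_in if encoder_mode == "multiple" else n_hidden
        branches = {}
        if encoder_mode != "single":
            branches = {name: mlp(branch_in, n_hidden) for name in site_dims}
        self.branches = nn.ModuleDict(branches)
        # Means and unconstrained scales depend linearly on the last hidden layer.
        self.loc = nn.ModuleDict(
            {name: nn.Linear(n_hidden, dim) for name, dim in site_dims.items()}
        )
        self.raw_scale = nn.ModuleDict(
            {name: nn.Linear(n_hidden, dim) for name, dim in site_dims.items()}
        )

    def forward(self, x):
        h_shared = self.shared(x) if self.shared is not None else x
        out = {}
        for name in self.loc:
            h = self.branches[name](h_shared) if name in self.branches else h_shared
            # softplus maps the unconstrained output to a positive scale
            scale = nn.functional.softplus(self.raw_scale[name](h))
            out[name] = (self.loc[name](h), scale)
        return out


x = torch.randn(8, 5)  # 8 observations, 5 input features
results = {}
for mode in ("single", "multiple", "single-multiple"):
    encoder = SketchEncoder(n_in=5, site_dims={"z": 3, "w": 1}, encoder_mode=mode)
    results[mode] = encoder(x)
```

In every mode the output for a site has shape (number of observations, site dimension), so the guide code that consumes the encoder would not need to know which architecture was chosen.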
Code is here for now: https://github.com/vitkl/scvi-tools/blob/pyro-cell2location/scvi/external/cell2location/autoguide.py
Still looking for good example data for tests. I will start working on this when we resubmit the paper revision (hopefully in August).