Exploit opportunities for analytic KL and entropy computations
Motivated by @fehiepsi's work on `TraceMeanField_ELBO` in NumPyro (https://github.com/pyro-ppl/numpyro/pull/748) and ongoing issues with the LDA example in Pyro. cc @fritzo @martinjankowiak
It is common in Pyro to use mean-field variational distributions with `TraceMeanField_ELBO` to reduce ELBO estimator variance. However, `TraceMeanField_ELBO` is too conservative: it cannot be applied to only part of a model, and it cannot be combined with Pyro's other inference tools, notably enumeration of discrete variables. This makes it difficult to perform variational inference reliably in models like LDA.
A better long-term approach would be to automatically identify ELBO fragments that admit analytic computation. Most such fragments are determined (almost) nonparametrically by conditional independence properties of the model and guide, so it should be possible in principle to cover a surprisingly wide range of Pyro models.
We can decompose this pattern-matching problem into two stages using Funsor: patterns for recognizing situations within a larger computation where analytic KL divergence and entropy computations may be used, and patterns for actually performing those computations using the backend distribution libraries' preexisting optimized implementations. These patterns could then be used seamlessly within any Funsor-based ELBO implementation, notably `pyro.contrib.funsor.infer.TraceEnum_ELBO`.
`funsor.optimize.optimize` already decomposes ELBO computations into conditionally independent fragments, although there are missing details like constant propagation that need to be handled with more generality (see also #163 and #109).
Thus, at a high level, the first stage just needs patterns that rewrite Monte Carlo expectations back to their analytic versions. Obviously this is only valid when we can guarantee that the Monte Carlo measure was drawn from the matching distribution, so these patterns would have to live in their own special interpretation:
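To make the distinction concrete, here is a minimal standalone illustration (plain Python with the stdlib, not funsor) of the two kinds of entropy terms being equated here: the closed-form entropy of a Normal distribution versus its Monte Carlo estimate. The variable names are mine, purely for illustration:

```python
import math
import random

random.seed(0)

# Analytic entropy of Normal(mu, sigma): 0.5 * log(2 * pi * e * sigma**2)
mu, sigma = 0.0, 2.0
analytic = 0.5 * math.log(2 * math.pi * math.e * sigma ** 2)

def log_prob(x):
    # log density of Normal(mu, sigma) at x
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

# Monte Carlo estimate of the same entropy: -E_p[log p(x)] over samples from p
samples = [random.gauss(mu, sigma) for _ in range(100_000)]
monte_carlo = -sum(log_prob(x) for x in samples) / len(samples)

# Both estimate the same quantity, but the analytic value has zero variance
print(analytic, monte_carlo)
```

The rewrite proposed in this issue is essentially the reverse direction: given a term that computes `monte_carlo`, recognize that it can be replaced by `analytic`.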
```python
@dispatched_interpretation
def analytic_recognizer(cls, *args):
    return analytic_recognizer.dispatch(cls, *args)(*args)

@analytic_recognizer.register(Integrate, Delta, Distribution, frozenset)
def recognize_analytic_entropy(log_measure, integrand, reduced_vars):
    ...  # check that the rewrite can be performed
    return Integrate(integrand, integrand, reduced_vars)
```
For added robustness, `analytic_recognizer` could be a `StatefulInterpretation` holding a mapping from `Delta` funsors to their sampling distribution funsors.
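The dispatch machinery sketched above can be approximated outside funsor with a small registry keyed on term and argument types. This is a hypothetical stand-alone sketch, not funsor's actual implementation; `Delta`, `Distribution`, and `Integrate` here are stand-in classes:

```python
from collections import namedtuple

# Stand-in funsor-like terms (hypothetical, for illustration only)
Delta = namedtuple("Delta", ["name", "point"])
Distribution = namedtuple("Distribution", ["name", "params"])
Integrate = namedtuple("Integrate", ["log_measure", "integrand", "reduced_vars"])

class Interpretation:
    """Minimal type-dispatched interpretation: maps (cls, arg types) to a rule."""
    def __init__(self):
        self.registry = {}

    def register(self, cls, *arg_types):
        def decorator(fn):
            self.registry[(cls, arg_types)] = fn
            return fn
        return decorator

    def __call__(self, cls, *args):
        rule = self.registry.get((cls, tuple(type(a) for a in args)))
        if rule is None:
            return cls(*args)  # no rule matches: build the term unevaluated
        return rule(*args)

analytic_recognizer = Interpretation()

@analytic_recognizer.register(Integrate, Delta, Distribution, frozenset)
def recognize_analytic_entropy(log_measure, integrand, reduced_vars):
    # Replace the Monte Carlo (Delta) measure with the distribution itself
    return Integrate(integrand, integrand, reduced_vars)

# A Monte Carlo entropy term is rewritten so both slots hold the distribution:
mc_term = analytic_recognizer(
    Integrate, Delta("x", 0.3), Distribution("Normal", (0.0, 1.0)), frozenset({"x"})
)
print(mc_term)
```

The `StatefulInterpretation` variant would additionally carry the `Delta`-to-distribution mapping as instance state, so the rule could verify that the sample really came from the distribution it is being matched against.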
For the second stage, we'll need `eager` patterns that are evaluated using the backend `.entropy` or `kl` implementations:
```python
@eager.register(Integrate, Distribution, Distribution, frozenset)
def eager_analytic_entropy(log_measure, integrand, reduced_vars):
    name_to_dim, dim_to_name = ...  # arbitrary name-dim mapping
    # call the backend's TorchDistribution.entropy()
    entropy_raw = -funsor.to_data(integrand, name_to_dim).entropy()
    return funsor.to_funsor(entropy_raw, funsor.Real, dim_to_name)
```
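For reference, the KL case of stage two would lean on a closed-form formula like the one below, which is what backend `kl` registries implement for simple pairs of distributions. This is a stdlib-only sketch for univariate Normals, checked against a Monte Carlo estimate of the same quantity:

```python
import math
import random

def kl_normal(mu_p, sigma_p, mu_q, sigma_q):
    # Closed-form KL(N(mu_p, sigma_p) || N(mu_q, sigma_q))
    return (math.log(sigma_q / sigma_p)
            + (sigma_p ** 2 + (mu_p - mu_q) ** 2) / (2 * sigma_q ** 2)
            - 0.5)

def log_prob(x, mu, sigma):
    # log density of Normal(mu, sigma) at x
    return -0.5 * math.log(2 * math.pi * sigma ** 2) - (x - mu) ** 2 / (2 * sigma ** 2)

random.seed(0)
mu_p, sigma_p, mu_q, sigma_q = 0.0, 1.0, 1.0, 2.0

analytic = kl_normal(mu_p, sigma_p, mu_q, sigma_q)

# Monte Carlo estimate: E_p[log p(x) - log q(x)] over samples from p
samples = [random.gauss(mu_p, sigma_p) for _ in range(100_000)]
mc = sum(log_prob(x, mu_p, sigma_p) - log_prob(x, mu_q, sigma_q)
         for x in samples) / len(samples)

print(analytic, mc)
```

The `eager` pattern's job is only to marshal funsor terms into backend distribution objects (via `to_data`/`to_funsor`) so that an optimized implementation of such a formula can be reused.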
With these patterns in hand, computing analytic entropy or KL terms in `pyro.contrib.funsor.infer.TraceEnum_ELBO` shouldn't involve much beyond using the new `analytic_recognizer` interpretation when evaluating the final ELBO funsor expression.