
Exploit opportunities for analytic KL and entropy computations


Motivated by @fehiepsi's work on TraceMeanField_ELBO in NumPyro https://github.com/pyro-ppl/numpyro/pull/748 and ongoing issues with the LDA example in Pyro. cc @fritzo @martinjankowiak

It is common in Pyro to use mean-field variational distributions with TraceMeanField_ELBO to reduce ELBO estimator variance. However, TraceMeanField_ELBO is too conservative: it cannot be applied to only part of a model or combined with Pyro's other inference machinery, notably enumeration of discrete variables, which makes it difficult to perform variational inference reliably in models like LDA.

A better long-term approach would be to automatically identify ELBO fragments that admit analytic computation. Most such fragments are determined (almost) nonparametrically by conditional independence properties of the model and guide, so it should be possible in principle to cover a surprisingly wide range of Pyro models.

Using Funsor, we can decompose this pattern-matching problem into two stages: first, patterns that recognize places within a larger computation where analytic KL divergence and entropy computations can be substituted; second, patterns that actually perform those computations using the backend distribution libraries' existing optimized implementations. These patterns could then be used seamlessly within any Funsor-based ELBO implementation, notably pyro.contrib.funsor.infer.TraceEnum_ELBO.
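
To make the two stages concrete, consider a single latent variable z with guide distribution funsor q and a posterior sample represented as a Delta funsor delta (names hypothetical):

# Stage 1 (recognition): rewrite the single-sample Monte Carlo estimate
#     Integrate(delta, q, frozenset({"z"}))   # log q(z) evaluated at one sample
# into its analytic counterpart
#     Integrate(q, q, frozenset({"z"}))       # E_q[log q(z)] = -H(q)
#
# Stage 2 (evaluation): compute Integrate(q, q, ...) in closed form by calling
# the backend's q.entropy() rather than evaluating the integral numerically.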

funsor.optimizer.optimize already decomposes ELBO computations into conditionally independent fragments, although some missing pieces, like constant propagation, still need to be handled more generally (see also #163, #109).
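
For reference, here is a minimal runnable sketch of that machinery in action (imports follow funsor's module layout; the interpretation API has moved between funsor versions, so treat this as illustrative):

from collections import OrderedDict

import torch
import funsor
from funsor.interpreter import interpretation
from funsor.optimizer import apply_optimizer
from funsor.terms import lazy

funsor.set_backend("torch")

x = funsor.Tensor(torch.randn(3), OrderedDict(i=funsor.Bint[3]))
y = funsor.Tensor(torch.randn(4), OrderedDict(j=funsor.Bint[4]))
with interpretation(lazy):
    z = (x + y).reduce(funsor.ops.logaddexp, frozenset({"i", "j"}))

print(apply_optimizer(z))  # the logsumexp factorizes into independent sums over i and j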

Thus, at a high level, the first stage just needs patterns that rewrite Monte Carlo expectations back to their analytic versions. Obviously this is only sound when we can guarantee that the Monte Carlo measure was actually sampled from the distribution appearing in the integrand, so these patterns would have to live in their own special interpretation:

# imports assume the funsor module layout current as of this issue
from funsor.delta import Delta
from funsor.distribution import Distribution
from funsor.integrate import Integrate
from funsor.interpreter import dispatched_interpretation

@dispatched_interpretation
def analytic_recognizer(cls, *args):
    return analytic_recognizer.dispatch(cls, *args)(*args)

@analytic_recognizer.register(Integrate, Delta, Distribution, frozenset)
def recognize_analytic_entropy(log_measure, integrand, reduced_vars):
    ...  # check that the rewrite can be performed (the Delta is a sample from integrand)
    # replace the single-sample Delta measure with the distribution itself:
    # Integrate(q, log q, vars) == E_q[log q] == -H(q)
    return Integrate(integrand, integrand, reduced_vars)

For added robustness, analytic_recognizer could be a StatefulInterpretation holding a mapping from Delta funsors to their sampling distribution funsors.
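
A hedged sketch of that stateful variant, modeled on funsor's MonteCarlo interpretation (the StatefulInterpretation constructor signature has changed across funsor versions, and the class and function names here are hypothetical):

from funsor.interpreter import StatefulInterpretation  # funsor.interpretations in later versions

class AnalyticRecognizer(StatefulInterpretation):
    def __init__(self, delta_to_dist):
        super().__init__("analytic_recognizer")  # name argument per newer funsor versions
        self.delta_to_dist = delta_to_dist  # {Delta funsor: Distribution funsor it was sampled from}

@AnalyticRecognizer.register(Integrate, Delta, Distribution, frozenset)
def stateful_recognize_entropy(state, log_measure, integrand, reduced_vars):
    # rewrite only if this Delta is known to be a sample from the integrand's distribution
    if state.delta_to_dist.get(log_measure) is integrand:
        return Integrate(integrand, integrand, reduced_vars)
    return None  # no match: fall back to the default interpretation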

For the second stage, we'll need eager patterns that evaluate these terms using the backend distributions' existing .entropy() or kl_divergence() implementations:

import funsor
from funsor.terms import eager

@eager.register(Integrate, Distribution, Distribution, frozenset)
def eager_analytic_entropy(log_measure, integrand, reduced_vars):
    # only sound when log_measure and integrand are the same distribution funsor
    name_to_dim, dim_to_name = ...  # arbitrary name-dim mapping
    entropy_raw = -funsor.to_data(integrand, name_to_dim).entropy()  # e.g. TorchDistribution.entropy()
    return funsor.to_funsor(entropy_raw, funsor.Real, dim_to_name)
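
The same registration could be generalized to cover analytic KL terms too. A hedged sketch (the function below is hypothetical, not part of this proposal's code), using the identity E_q[log p] = -H(q) - KL(q || p) and the torch backend's kl_divergence:

from torch.distributions import kl_divergence

@eager.register(Integrate, Distribution, Distribution, frozenset)
def eager_analytic_expected_logprob(log_measure, integrand, reduced_vars):
    name_to_dim, dim_to_name = ...  # arbitrary name-dim mapping, as above
    q = funsor.to_data(log_measure, name_to_dim)
    if log_measure is integrand:  # funsors are cons-hashed, so `is` tests structural equality
        raw = -q.entropy()  # E_q[log q] = -H(q)
    else:
        p = funsor.to_data(integrand, name_to_dim)
        raw = -q.entropy() - kl_divergence(q, p)  # E_q[log p] = -H(q) - KL(q || p)
    return funsor.to_funsor(raw, funsor.Real, dim_to_name)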

With these patterns in hand, computing analytic entropy or KL terms in pyro.contrib.funsor.infer.TraceEnum_ELBO shouldn't require much beyond evaluating the final ELBO funsor expression under the new analytic_recognizer interpretation.
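
For instance (hypothetical glue code, where lazy_elbo stands for the lazily built ELBO funsor):

from funsor.interpreter import interpretation, reinterpret

with interpretation(analytic_recognizer):
    elbo = reinterpret(lazy_elbo)  # Monte Carlo terms are rewritten, then computed analytically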

eb8680 · Sep 29 '20