
Full-featured distribution wrappers

Open eb8680 opened this issue 4 years ago • 2 comments

The biggest single issue blocking wider use of Funsor in Pyro and NumPyro right now is the incomplete coverage of distributions.

At a high level, the goal is to be able to perform all distribution operations that appear in any Pyro or NumPyro model (e.g. sampling and scoring) directly on Funsors: distribution funsors are obtained by using to_funsor to automatically convert distributions to Funsors initially, and funsor.to_data automatically converts the final results back to raw PyTorch/JAX objects. Wherever possible, the Funsor wrappers should also avoid the need for user-facing higher-order distributions, such as Independent or TransformedDistribution, in favor of idiomatic Funsor operations or broadcasting semantics.
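To make the round-trip pattern concrete, here is a minimal stdlib sketch of type-directed to_funsor/to_data conversion. All classes here are hypothetical stand-ins, not the real funsor or Pyro API; the point is only the dispatch-and-invert shape of the conversion.

```python
# Hypothetical sketch of the to_funsor/to_data round trip via type dispatch.
# NormalFunsor and RawNormal are stand-ins, not real funsor/backend classes.
from dataclasses import dataclass
from functools import singledispatch

@dataclass
class NormalFunsor:      # stand-in for a funsor distribution term
    loc: float
    scale: float

@dataclass
class RawNormal:         # stand-in for a backend (PyTorch/JAX) distribution
    loc: float
    scale: float

@singledispatch
def to_funsor(x):
    raise NotImplementedError(f"cannot convert {type(x).__name__}")

@to_funsor.register
def _(d: RawNormal) -> NormalFunsor:
    return NormalFunsor(d.loc, d.scale)

@singledispatch
def to_data(x):
    raise NotImplementedError(f"cannot convert {type(x).__name__}")

@to_data.register
def _(f: NormalFunsor) -> RawNormal:
    return RawNormal(f.loc, f.scale)

d = RawNormal(0.0, 1.0)
assert to_data(to_funsor(d)) == d  # round trip recovers the backend object
```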

We have gotten pretty far with the generic wrappers in funsor.distribution and funsor.{jax,torch}.distributions, but finishing the job and achieving full coverage of pyro.distributions remains a challenge because of the large distribution API and sheer number of distributions, the many small impedance mismatches and legacy design choices in PyTorch (e.g. the data type of Bernoulli), and the difficulty of programmatic access and automation (e.g. there is no generic tool for constructing random valid instances of a distribution given a batch_shape).

I've tried to collect the remaining Funsor-specific tasks in this issue so that we can better measure progress toward this goal. We may also need to do additional work upstream in Pyro, NumPyro, or PyTorch distributions.

Transforms and TransformedDistributions (some design discussion in #309):

  • [x] #387 Near-complete coverage of parameter-free invertible Transforms in Pyro
  • [x] #365 Basic conversion of TransformedDistributions to and from funsor.Distributions on the PyTorch backend
  • [x] #427 Automatic wrapping of non-batched Transforms
  • [ ] Conversion of TransformedDistributions to and from funsors on the JAX backend (can copy #365)
  • [ ] Coverage for distribution conversion with non-invertible Transforms like AbsTransform
  • [ ] Coverage for conversion of invertible Transforms with ground parameters, notably AffineTransform and PowerTransform with batch_shape == ()
  • [ ] Coverage for conversion of invertible Transforms with non-ground parameters, e.g. AffineTransform with len(batch_shape) > 0
  • [ ] Wrappers for ConditionalTransforms and TransformedModules in Pyro/PyTorch
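For the ground-parameter cases above, the conversion boils down to three pieces per transform: the forward map, its inverse, and the log absolute Jacobian determinant. A minimal stdlib sketch for AffineTransform with scalar (batch_shape == ()) parameters, not tied to the real PyTorch/NumPyro classes:

```python
import math
from dataclasses import dataclass

@dataclass
class AffineTransform:
    """Sketch of y = loc + scale * x with scalar ("ground") parameters."""
    loc: float
    scale: float

    def __call__(self, x: float) -> float:
        return self.loc + self.scale * x

    def inv(self, y: float) -> float:
        return (y - self.loc) / self.scale

    def log_abs_det_jacobian(self, x: float) -> float:
        # constant in x for an affine map: log |d(loc + scale*x)/dx|
        return math.log(abs(self.scale))

t = AffineTransform(loc=1.0, scale=2.0)
assert t.inv(t(3.0)) == 3.0
assert t.log_abs_det_jacobian(3.0) == math.log(2.0)
```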

Other basic distribution modifiers:

  • [x] #391 automated to_funsor conversion for custom distributions
  • [x] #394 Near-full coverage of basic distributions (easy enough to implement; the difficulty is testing and fixing PyTorch bugs)
  • [x] #394 Get backend Delta conversion working
  • [x] #396 Converting funsor.Independent to and from Independent distributions
  • [x] #402 more idiomatic support for Independent distributions in Funsor
  • [x] #418 Converting lazy Pyro ExpandedDistributions to funsors
  • [x] #419 Converting lazy Numpyro ExpandedDistributions to funsors
  • [x] #432 Handle conversion of parameters with _IndependentConstraint arg_constraints
  • [x] #443 Convert IndependentDistributions directly to base Funsor distributions
  • [ ] Support for converting mixture distributions into funsor sum-product expressions, e.g. MaskedMixture and MixtureSameFamily
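The mixture item above amounts to marginalizing a discrete component index in log space: log p(x) = logsumexp_k(log w_k + log p_k(x)). A self-contained stdlib sketch with Gaussian components (not the real MixtureSameFamily API):

```python
import math

def normal_logpdf(x, loc, scale):
    return (-0.5 * ((x - loc) / scale) ** 2
            - math.log(scale * math.sqrt(2 * math.pi)))

def logsumexp(vals):
    m = max(vals)
    return m + math.log(sum(math.exp(v - m) for v in vals))

def mixture_logpdf(x, log_weights, components):
    # sum-product expression: sum out the discrete mixture index, in log space
    return logsumexp([lw + normal_logpdf(x, loc, scale)
                      for lw, (loc, scale) in zip(log_weights, components)])

# equal-weight mixture of N(-1, 1) and N(1, 1), scored at x = 0
lp = mixture_logpdf(0.0, [math.log(0.5)] * 2, [(-1.0, 1.0), (1.0, 1.0)])
```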

Masking:

  • [ ] First-class support for masking using logical operations rather than floats
  • [ ] Optimized handling for scalar masks (i.e. the case where mask is False) (discussed in #459)
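To illustrate both masking items: with first-class boolean masks, a masked-out site contributes exactly zero log density (no float multiplication of possibly NaN scores), and a scalar False mask can short-circuit scoring entirely. A hypothetical stdlib sketch, not the real Funsor or Pyro masking API:

```python
def masked_log_prob(log_prob_fn, value, mask):
    """Score `value` only where `mask` is True, using logical ops, not floats.

    `mask` may be a scalar bool (short-circuit fast path) or a list of bools.
    """
    if isinstance(mask, bool):
        # scalar-mask fast path: mask is False means nothing is scored at all
        return log_prob_fn(value) if mask else 0.0
    return sum(log_prob_fn(v) if m else 0.0 for v, m in zip(value, mask))

lp = lambda x: -x * x  # toy unnormalized log density
assert masked_log_prob(lp, 2.0, False) == 0.0
assert masked_log_prob(lp, [1.0, 2.0], [True, False]) == -1.0
```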

Direct TFP distribution wrappers:

  • [ ] Direct wrapping of TFP distributions in the JAX backend - this would probably involve a new subclass TFPDistribution(funsor.distribution.Distribution) with TFP-specific implementations of _infer_value_domain and _infer_param_domain
  • [ ] Direct TFP Bijector wrappers

Atomic distribution computations beyond sampling and scoring implemented in the backend libraries:

  • [x] #388 Entropy (discussed in #374)
  • [x] #388 Mean (TorchDistribution.mean)
  • [x] #388 Variance (TorchDistribution.variance)
  • [ ] KL divergence (discussed in #374)
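For the KL item, the backend libraries already provide closed forms for many same-family pairs (e.g. torch.distributions.kl_divergence); the Funsor work is wiring those in, not deriving them. As a reference point, here is the standard closed form for two univariate Gaussians, KL(N(mu1, s1^2) || N(mu2, s2^2)) = log(s2/s1) + (s1^2 + (mu1 - mu2)^2) / (2 s2^2) - 1/2, as a stdlib sketch:

```python
import math

def kl_normal_normal(mu1, s1, mu2, s2):
    """Closed-form KL(N(mu1, s1^2) || N(mu2, s2^2)) for univariate Gaussians."""
    return (math.log(s2 / s1)
            + (s1 ** 2 + (mu1 - mu2) ** 2) / (2 * s2 ** 2)
            - 0.5)

assert kl_normal_normal(0.0, 1.0, 0.0, 1.0) == 0.0  # identical distributions
```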

Test harnesses for distribution wrappers (testing correctness of underlying distribution functionality here is out of scope - we are mostly interested in ensuring that results are converted to Funsors correctly):

  • [x] #389 conversion with to_funsor and to_data
  • [x] #389 types and shapes
  • [x] #389 densities
  • [x] #389 enumeration (enumerate_support)
  • [x] #389 samplers
  • [ ] transforms

Miscellaneous:

  • [ ] Find a workaround for int/float casting issues in Bernoulli (discussed in #348)
  • [ ] Generic support for sampling via Funsor.sample() in funsor.pyro.FunsorDistribution - done properly, this may eliminate the need for first-class conjugate distribution implementations in backends, e.g. DirichletMultinomial
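On the last point: the conjugate marginal that a first-class DirichletMultinomial implementation computes has a simple log-gamma form, which is the kind of result generic Funsor sampling could recover automatically. A stdlib sketch of the Dirichlet-Multinomial log pmf (the Dirichlet prior integrated out), purely for reference:

```python
from math import lgamma, fsum

def dirichlet_multinomial_logpmf(counts, alpha):
    """log p(counts | alpha) with the Dirichlet prior marginalized out."""
    n = sum(counts)
    a = sum(alpha)
    log_multinomial_coef = lgamma(n + 1) - fsum(lgamma(c + 1) for c in counts)
    log_beta_ratio = (lgamma(a) - lgamma(a + n)
                      + fsum(lgamma(ai + ci) - lgamma(ai)
                             for ai, ci in zip(alpha, counts)))
    return log_multinomial_coef + log_beta_ratio

# with a flat Dirichlet(1, 1) prior, every composition of n = 3 into 2 parts
# is equally likely: there are 4 compositions, so each has pmf 1/4
```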

Lower priority, possibly unnecessary:

  • conjugate_update wrapper (may be unnecessary?)
  • #395 Monte Carlo tests for samplers and their gradients (low priority)
  • Some Funsor analogue of Distribution.expand() e.g. via lazy units as in #235 or substitution of indexed Variable terms (f(v=v['i']))
  • Generic tests for conjugate distribution pairs
  • Deprecate existing, case-by-case tests

eb8680 avatar Oct 26 '20 18:10 eb8680

Conversion of TransformedDistributions to and from funsors on the JAX backend (can copy #365)

Note this will be a bit trickier than expected, because in NumPyro Transform.inv is just a regular method rather than a property that returns an _InverseTransform object, a class NumPyro does not implement.
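One possible workaround is a small wrapper that reifies a transform's inverse as a first-class object, so the conversion code can treat both backends uniformly. A hypothetical stdlib sketch (not NumPyro's actual classes):

```python
import math

class InverseTransform:
    """Hypothetical sketch: reify t.inv as a first-class transform object."""

    def __init__(self, base):
        self.base = base

    def __call__(self, y):
        return self.base.inv(y)  # forward of the inverse = inverse of the base

    def inv(self, x):
        return self.base(x)      # inverting twice recovers the base transform

class ExpTransform:
    """Toy stand-in transform: y = exp(x)."""
    def __call__(self, x):
        return math.exp(x)
    def inv(self, y):
        return math.log(y)

t = InverseTransform(ExpTransform())
assert abs(t(1.0) - 0.0) < 1e-12      # log(1) == 0
assert abs(t.inv(0.0) - 1.0) < 1e-12  # exp(0) == 1
```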

eb8680 avatar Nov 27 '20 19:11 eb8680

I will make a PR for that. :)

fehiepsi avatar Nov 27 '20 23:11 fehiepsi