Full-featured distribution wrappers
The biggest single issue blocking wider use of Funsor in Pyro and NumPyro right now is the incomplete coverage of distributions.
At a high level, the goal is to be able to perform all distribution operations that appear in any Pyro or NumPyro model (e.g. sampling and scoring) on Funsors directly, where distribution funsors are obtained by using `to_funsor` to automatically convert distributions to Funsors up front and `funsor.to_data` to automatically convert the final results back to raw PyTorch/JAX objects. Wherever possible, the Funsor wrappers should also avoid the need for user-facing higher-order distributions, such as `Independent` or `TransformedDistribution`, in favor of idiomatic Funsor operations or broadcasting semantics.
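For concreteness, here is a minimal sketch of that round trip on the PyTorch backend. It assumes a recent funsor release where `to_funsor`/`to_data` take `output`, `dim_to_name`, and `name_to_dim` arguments and where domains are spelled `funsor.Real`/`funsor.Bint`; treat the details as illustrative rather than a tested recipe.

```python
from collections import OrderedDict

import torch
import pyro.distributions as dist
import funsor

funsor.set_backend("torch")

# A batched backend distribution...
d = dist.Normal(torch.zeros(3), torch.ones(3))

# ...converted to a funsor whose rightmost batch dim becomes the named input "i"
# and whose support becomes the free variable "value".
f = funsor.to_funsor(d, output=funsor.Real, dim_to_name={-1: "i"})

# Scoring is substitution: plugging in a value yields the log-density as a funsor.
x = funsor.Tensor(torch.randn(3), OrderedDict(i=funsor.Bint[3]))
log_prob = f(value=x)

# Convert the result back to a raw tensor, mapping the name "i" back to dim -1.
log_prob_data = funsor.to_data(log_prob, name_to_dim={"i": -1})
```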
We have gotten pretty far with the generic wrappers in `funsor.distribution` and `funsor.{jax,torch}.distributions`, but finishing the job and achieving full coverage of `pyro.distributions` remains a challenge because of the large distribution API and sheer number of distributions, the many small impedance mismatches and legacy design choices in PyTorch (e.g. the data type of `Bernoulli`), and the difficulty of programmatic access and automation (e.g. there is no generic tool for constructing random valid instances of a distribution given a `batch_shape`).
I've tried to collect the remaining Funsor-specific tasks in this issue so that we can better measure progress toward this goal. We may also need to do additional work upstream in Pyro, NumPyro or PyTorch distributions.
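On that last point, a generic instance builder does not exist upstream; the following is a rough, hypothetical sketch of what one could look like, using only public `torch.distributions` machinery (`arg_constraints` and `transform_to`). It ignores parameters whose constrained shape differs from their unconstrained shape (e.g. simplex-valued parameters), which is exactly the kind of corner case that makes automating this hard.

```python
import torch
import torch.distributions as torch_dist
from torch.distributions import transform_to


def random_valid_instance(dist_class, batch_shape):
    """Hypothetical helper: build a random, constraint-satisfying instance."""
    params = {}
    for name, constraint in dist_class.arg_constraints.items():
        # Draw an unconstrained tensor, then push it through the bijection
        # onto the constraint's support.
        unconstrained = torch.randn(batch_shape)
        params[name] = transform_to(constraint)(unconstrained)
    return dist_class(**params)


# Example: a random batch of Gammas with batch_shape (2, 3).
d = random_valid_instance(torch_dist.Gamma, (2, 3))
assert d.batch_shape == torch.Size([2, 3])
```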
`Transform`s and `TransformedDistribution`s (some design discussion in #309):
- [x] #387 Near-complete coverage of parameter-free invertible `Transform`s in Pyro
- [x] #365 Basic conversion of `TransformedDistribution`s to and from `funsor.Distribution`s on the PyTorch backend (see the sketch after this list)
- [x] #427 Automatic wrapping of non-batched `Transform`s
- [ ] Conversion of `TransformedDistribution`s to and from funsors on the JAX backend (can copy #365)
- [ ] Coverage for distribution conversion with non-invertible `Transform`s like `AbsTransform`
- [ ] Coverage for conversion of invertible `Transform`s with ground parameters, notably `AffineTransform` and `PowerTransform` with `batch_shape == ()`
- [ ] Coverage for conversion of invertible `Transform`s with non-ground parameters, e.g. `AffineTransform` with `len(batch_shape) > 0`
- [ ] Wrappers for `ConditionalTransform`s and `TransformedModule`s in Pyro/PyTorch
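As a reference point for the remaining items, here is a small sketch of the PyTorch-backend conversion that #365 enabled, writing LogNormal explicitly in `TransformedDistribution` form. The exact set of supported transforms depends on the parameter-free coverage tracked in #387, so treat this as illustrative.

```python
import torch
import pyro.distributions as dist
import funsor

funsor.set_backend("torch")

# LogNormal written explicitly as TransformedDistribution(Normal, ExpTransform).
base = dist.Normal(torch.tensor(0.), torch.tensor(1.))
d = dist.TransformedDistribution(base, [dist.transforms.ExpTransform()])

# Convert to a funsor with a free "value" input, score by substitution, and
# convert the resulting log-density back to a raw tensor.
f = funsor.to_funsor(d, output=funsor.Real)
x = funsor.Tensor(torch.tensor(2.5))
log_prob = funsor.to_data(f(value=x))  # should match d.log_prob(torch.tensor(2.5))
```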
Other basic distribution modifiers:
- [x] #391 automated `to_funsor` conversion for custom distributions
- [x] #394 Near-full coverage of basic distributions (easy enough to implement; the difficulty is testing and fixing PyTorch bugs)
- [x] #394 Get backend `Delta` conversion working
- [x] #396 Converting `funsor.Independent` to and from `Independent` distributions
- [x] #402 more idiomatic support for `Independent` distributions in Funsor
- [x] #418 Converting lazy Pyro `ExpandedDistribution`s to funsors
- [x] #419 Converting lazy NumPyro `ExpandedDistribution`s to funsors
- [x] #432 Handle conversion of parameters with `_IndependentConstraint` `arg_constraints`
- [x] #443 Convert `IndependentDistribution`s directly to base Funsor distributions
- [ ] Support for converting mixture distributions into funsor sum-product expressions, e.g. `MaskedMixture` and `MixtureSameFamily` (see the sketch after this list)
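To make the mixture item concrete, here is a conceptual, hypothetical sketch (PyTorch backend, assuming `funsor.Bint` domains; the variable name "component" is ours) of what "mixture as a funsor sum-product expression" could mean: the discrete component index becomes a named input that is marginalized out by a logsumexp reduction.

```python
from collections import OrderedDict

import torch
import pyro.distributions as dist
import funsor
import funsor.ops as ops

funsor.set_backend("torch")

K = 3
component = OrderedDict(component=funsor.Bint[K])

# Mixing logits indexed by a named "component" dimension.
logits = funsor.Tensor(torch.randn(K), component)
log_prior = logits - logits.reduce(ops.logaddexp, "component")

# Per-component likelihood: a batched Normal whose batch dim is named
# "component", scored at an observed value x.
x = funsor.Tensor(torch.tensor(0.5))
likelihood = funsor.to_funsor(
    dist.Normal(torch.randn(K), torch.rand(K) + 0.5),
    output=funsor.Real,
    dim_to_name={-1: "component"},
)
log_lik = likelihood(value=x)

# log p(x): the sum-product step, reducing out the discrete mixture index.
log_marginal = (log_prior + log_lik).reduce(ops.logaddexp, "component")
```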
Masking:
- [ ] First-class support for masking using logical operations rather than floats
- [ ] Optimized handling for scalar masks (i.e. the case `mask is False`, discussed in #459); see the sketch after this list
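A hypothetical illustration of the scalar-mask point (the helper name and dispatch are made up for this sketch): when the mask is the Python literal `False`, the masked log-density should collapse to a constant zero without ever evaluating the base density.

```python
import funsor


def apply_mask(log_prob, mask):
    """Hypothetical helper showing the intended special cases for scalar masks."""
    if mask is False:
        # Scalar mask: the site contributes nothing; skip the density entirely.
        return funsor.Number(0.)
    if mask is True:
        # Scalar mask: no-op.
        return log_prob
    # Elementwise mask: the current float-based behavior this issue would like
    # to replace with first-class logical operations.
    return log_prob * mask
```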
Direct TFP distribution wrappers:
- [ ] Direct wrapping of TFP distributions in the JAX backend - this would probably involve a new subclass `class TFPDistribution(funsor.distribution.Distribution)` with TFP-specific implementations of `_infer_value_domain` and `_infer_param_domain` (see the sketch after this list)
- [ ] Direct TFP `Bijector` wrappers
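A rough sketch of what that subclass could look like, following the hook names mentioned above; the `dist_class` attribute and the method bodies are assumptions about how the generated wrappers would work, not a working implementation.

```python
import funsor
from funsor.distribution import Distribution


class TFPDistribution(Distribution):
    """Hypothetical wrapper for tensorflow_probability.substrates.jax distributions."""

    # Assumed to be filled in per-distribution when concrete wrapper classes are
    # generated, mirroring the existing torch/numpyro wrappers.
    dist_class = None

    @classmethod
    def _infer_value_domain(cls, **kwargs):
        # Build a throwaway TFP instance from the raw parameter values and read
        # off its event shape; assume real-valued support for this sketch.
        raw_params = {name: funsor.to_data(value) for name, value in kwargs.items()}
        event_shape = tuple(cls.dist_class(**raw_params).event_shape)
        return funsor.Reals[event_shape] if event_shape else funsor.Real

    @classmethod
    def _infer_param_domain(cls, name, raw_shape):
        # Without TFP-specific constraint metadata, fall back to treating every
        # parameter as an unconstrained real tensor of the given shape.
        return funsor.Reals[raw_shape] if raw_shape else funsor.Real
```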
Atomic distribution computations beyond sampling and scoring that are implemented in the backend libraries:
- [x] #388 Entropy (discussed in #374)
- [x] #388 Mean (`TorchDistribution.mean`)
- [x] #388 Variance (`TorchDistribution.variance`)
- [ ] KL divergence (discussed in #374)
Test harnesses for distribution wrappers (testing the correctness of the underlying distribution functionality is out of scope here - we are mostly interested in ensuring that results are converted to Funsors correctly):
- [x] #389 conversion with `to_funsor` and `to_data`
- [x] #389 types and shapes
- [x] #389 densities
- [x] #389 enumeration (`enumerate_support`)
- [x] #389 samplers
- [ ] transforms
Miscellaneous:
- [ ] Find a workaround for int/float casting issues in `Bernoulli` (discussed in #348); see the snippet after this list
- [ ] Generic support for sampling via `Funsor.sample()` in `funsor.pyro.FunsorDistribution` - done properly, this may eliminate the need for first-class conjugate distribution implementations in backends, e.g. `DirichletMultinomial`
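For reference, the casting mismatch in the first item comes from the PyTorch convention that `Bernoulli` values are float tensors even though the support is {0, 1}, whereas funsor models discrete supports like `Categorical`'s with integer (`Bint`) domains. A quick demonstration of the PyTorch side:

```python
import torch
import torch.distributions as torch_dist

b = torch_dist.Bernoulli(probs=torch.tensor(0.3))
sample = b.sample()

print(sample.dtype)        # torch.float32, even though the support is {0, 1}
print(b.log_prob(sample))  # log_prob likewise expects float-valued inputs
```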
Lower priority, possibly unnecessary:
- `conjugate_update` wrapper (may be unnecessary?)
- #395 Monte Carlo tests for samplers and their gradients (low priority)
- Some Funsor analogue of `Distribution.expand()`, e.g. via lazy units as in #235 or substitution of indexed `Variable` terms (`f(v=v['i'])`)
- Generic tests for conjugate distribution pairs
- deprecate existing, case-by-case tests
> Conversion of `TransformedDistribution`s to and from funsors on the JAX backend (can copy #365)
Note this will be a bit trickier than expected, because in NumPyro `Transform.inv` is just a regular method rather than a property returning an `_InverseTransform`, a class NumPyro does not implement.
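To make the mismatch concrete, here is the PyTorch behavior, with the NumPyro side stated only in a comment since it reflects the NumPyro API at the time this was written:

```python
import torch
import torch.distributions.transforms as T

t = T.ExpTransform()

# PyTorch: .inv is a property that returns an _InverseTransform object, which
# can then be applied (and composed) like any other Transform.
inv_t = t.inv
x = inv_t(torch.ones(3))  # log(1) == 0

# NumPyro (at the time of this issue): Transform.inv was an ordinary method that
# applies the inverse directly, e.g. ExpTransform().inv(y), with no inverse
# transform object available to convert to a funsor.
```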
I will make a PR for that. :)