pyro
FR Automatic reparameterization strategies
As our available reparameterization strategies become increasingly complex and composable, it would be useful to provide some automatic strategies, possibly even enabled by default in autoguides and MCMC.
Maybe this could work with syntax like `poutine.reparam(config="auto")` or `poutine.reparam(config="full")`.
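To make the proposed string shorthand concrete, here is a minimal sketch of how a `config` argument might be normalized into a per-site callable. It assumes `config` can otherwise be a dict from site name to reparametrizer or a callable on the site. The helper `resolve_config` and the `strategies` registry are hypothetical names for illustration only, not part of Pyro.

```python
from typing import Callable, Dict, Optional, Union

# A per-site config: given a trace site (a dict with at least a "name" key),
# return a reparametrizer for that site, or None to leave it unchanged.
SiteConfig = Callable[[dict], Optional[object]]


def resolve_config(
    config: Union[str, Dict[str, object], SiteConfig],
    strategies: Dict[str, SiteConfig],
) -> SiteConfig:
    """Hypothetical helper: normalize any accepted config form to a callable."""
    if callable(config):
        # Already a per-site callable; use it unchanged.
        return config
    if isinstance(config, str):
        # Named strategy such as "minimal", "auto" or "full".
        if config not in strategies:
            raise ValueError(f"unknown reparam strategy: {config!r}")
        return strategies[config]
    if isinstance(config, dict):
        # Explicit mapping from site name to reparametrizer.
        return lambda site: config.get(site["name"])
    raise TypeError(f"unsupported config type: {type(config).__name__}")
```

With this shape, the existing dict-based usage keeps working and a named strategy is just another way to produce the same kind of callable.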
Proposed automatic strategies
- "minimal": A minimal strategy that reparametrizes only sites that require reparametrization for mathematical reasons, e.g. `StableReparam` for `Stable` distributions, `CircularReparam` #2740 for `VonMises`, `ProjectedNormalReparam` for `ProjectedNormal` sites.
- "full": Complete reparametrization after which all `pyro.sample` sites are parameter-free, as required by #2768. This should error on distributions with implicit reparameterization, e.g. `Gamma` with a learnable `concentration` parameter.
- "auto": Recommended reparametrization including something like
  - "minimal";
  - `LocScaleReparam` applied to all distributions with `.loc` and `.scale` attributes (will require unwrapping `Independent` and `Masked` distributions);
  - `SoftmaxReparam` for `Simplex`-constrained random variables.
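The selection rules for "minimal" and "auto" can be sketched as a single per-site function. This is only an illustration under stated assumptions: the distribution classes below are stand-ins so the snippet runs without Pyro, the function returns the *name* of a reparametrizer rather than an instance, wrapper detection by class name is a simplification, and `choose_reparam` itself is a hypothetical helper. The "full" strategy is omitted because it needs to inspect whether parameters are learnable.

```python
from dataclasses import dataclass
from typing import Any, Optional


# Stand-in distributions (NOT Pyro classes); only names and attributes matter.
@dataclass
class Normal:
    loc: float
    scale: float


@dataclass
class Stable:
    stability: float


@dataclass
class Dirichlet:
    concentration: float
    support: str = "simplex"  # stand-in for a simplex constraint object


@dataclass
class Independent:
    base_dist: Any


# "minimal": sites that must be reparametrized for mathematical reasons.
REQUIRED = {
    "Stable": "StableReparam",
    "VonMises": "CircularReparam",
    "ProjectedNormal": "ProjectedNormalReparam",
}
# Wrapper names as used in the text above; real class names may differ.
WRAPPERS = {"Independent", "Masked"}


def unwrap(fn: Any) -> Any:
    """Strip Independent/Masked wrappers to reach the base distribution."""
    while type(fn).__name__ in WRAPPERS:
        fn = fn.base_dist
    return fn


def choose_reparam(site: dict, strategy: str = "auto") -> Optional[str]:
    """Return the name of the reparametrizer for this site, or None."""
    if strategy not in ("minimal", "auto"):
        raise ValueError(f"unsupported strategy: {strategy!r}")
    if site.get("is_observed", False):
        return None  # observed sites are left alone
    fn = unwrap(site["fn"])
    name = type(fn).__name__
    if name in REQUIRED:
        return REQUIRED[name]
    if strategy == "minimal":
        return None
    # "auto" adds the recommended, but not mathematically required, rules.
    if hasattr(fn, "loc") and hasattr(fn, "scale"):
        return "LocScaleReparam"
    if getattr(fn, "support", None) == "simplex":
        return "SoftmaxReparam"
    return None
```

For example, `choose_reparam({"fn": Independent(Normal(0.0, 1.0))})` selects `"LocScaleReparam"` only because the wrapper is unwrapped first, which is the reason the "auto" item calls out `Independent` and `Masked`.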
Design questions
- [ ] Should we automatically use these in `AutoGuide` and MCMC? If so, we'd need to refactor `Predictive` to use the reparametrized `guide.model` rather than the raw `model` (this bit me recently), and we might want to add `guide.predict()` methods #2851.
- [ ] This will require more composition of reparametrizers. How should this composition work? Maybe a `.lift()` method? Or will lifting be automatic?
> Should we automatically use these in AutoGuide and MCMC?
Would it also make sense to refactor the automatic transformation to unconstrained space in MCMC and ADVI to use `poutine.reparam` and an automated "unconstrain" reparameterization strategy?
Also, the enumeration strategies are tantalizingly close to being `Reparam` classes. I wonder if we could actually rewrite them that way and make `config_enumerate` into a special case of `ReparamMessenger`.
> Would it also make sense to refactor the automatic transformation to unconstrained space...
Interesting, this seems similar to @fehiepsi's `uniform_reparam_transform()` in https://github.com/pyro-ppl/numpyro/pull/807.
Before we do any major refactoring, I'd like to see what issues come up with the new use cases of automatic reparametrization strategies in #2884. I expect #2884 to uncover design constraints that will inform how reparametrization transforms work more generally.
In NumPyro, we used `unconstrained_reparam` in HMC. For autoguide, I first tried to use that reparam but later decided to follow the Pyro implementation, because SVI requires a pair of `model` and (transformed) `guide`, not `reparamed_model` and diagonal `guide`.