Utilities for simplifying interactions between PyroSample and plates
Problem
PyroModule and PyroSample make it straightforward to compositionally specify probabilistic models with random parameters. However, PyroSample has a somewhat awkward interaction with pyro.plate:
import pyro
import pyro.distributions
import pyro.nn

class Model(pyro.nn.PyroModule):
    @pyro.nn.PyroSample
    def loc(self):
        return pyro.distributions.Normal(0, 1)

    @pyro.nn.PyroSample
    def scale(self):
        return pyro.distributions.LogNormal(0, 1)

    def forward(self, x_obs):
        assert self.scale.shape == ()  # accessing self.scale triggers pyro.sample outside the plate
        with pyro.plate("data", x_obs.shape[0], dim=-1):
            assert self.loc.shape == (x_obs.shape[0],)  # accessing self.loc here triggers pyro.sample inside the plate
            return pyro.sample("x", pyro.distributions.Normal(self.loc, self.scale), obs=x_obs)
To ensure loc and scale are sampled globally, both must be accessed outside the data plate, as scale is in the example above. Inlining self.loc in the final line instead samples a different loc for each datapoint. This behavior is semantically unambiguous, but it can cause confusion in more complex models, and it forces the model to carry ugly boilerplate that manually samples the random parameters of submodules in the correct plate context.
For example, in the code below the intuitive behavior for Model.linear is clearly for linear.weight to be sampled outside of the data plate. However, because self.linear is invoked for the first time inside the plate, there will be a separate random copy of linear.weight for each plate slice:
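import torch
import pyro
import pyro.distributions as dist
import pyro.nn

class BayesianLinear(pyro.nn.PyroModule[torch.nn.Linear]):
    @pyro.nn.PyroSample
    def weight(self):
        # torch.nn.Linear exposes in_features/out_features and stores weight as [out, in]
        return dist.Normal(0, 1).expand([self.out_features, self.in_features]).to_event(2)

class Model(pyro.nn.PyroModule):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.linear = BayesianLinear(num_inputs, num_outputs)

    def forward(self, x):
        with pyro.plate("data", x.shape[-2], dim=-1):
            loc = self.linear(x)
            # weight was first accessed inside the plate, so there is one copy per datapoint
            assert self.linear.weight.shape[-3] == x.shape[-2]
            # the output dimension is an event dimension, as in the proposed-fix example
            return pyro.sample("y", dist.Normal(loc, 1).to_event(1))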
However, it would not be correct to simply ignore all plates when executing PyroSample statements. In this example, we might want to use a multi-sample ELBO estimator when inferring self.linear.weight (e.g. pyro.infer.Trace_ELBO(num_particles=10, vectorize_particles=True)), which is implemented with another plate that should not be ignored.
Proposed fix
It would be nice to have a feature that enables the intuitive behavior in the second example above without breaking backwards compatibility with PyroSample's existing semantics, or its correctness in the presence of enclosing plates such as the one introduced by the multi-sample ELBO.
This could potentially be achieved with a new handler, PyroSamplePlateScope, such that PyroSample statements executed inside its context are only modified by plates entered outside of it, while ordinary pyro.sample statements are unaffected and behave in the usual way:
class Model(pyro.nn.PyroModule):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.linear = BayesianLinear(num_inputs, num_outputs)

    @pyro.nn.PyroSample
    def scale(self):
        return pyro.distributions.LogNormal(0, 1).expand([self.num_outputs]).to_event(1)

    @PyroSamplePlateScope()
    def forward(self, x):
        with pyro.plate("data", x.shape[-2], dim=-1):
            loc = self.linear(x)
            assert self.linear.weight.shape[-3] == 1  # sampled outside data plate
            assert self.scale.shape[-2] == 1  # sampled outside data plate
            y = pyro.sample("y", dist.Normal(loc, self.scale).to_event(1))
            assert y.shape[-2] == x.shape[-2]  # ordinary pyro.sample statement
            return y