
Utilities for simplifying interactions between PyroSample and plates

Open issue, opened by eb8680 7 months ago (2 comments).


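PyroModule and PyroSample make it straightforward to specify probabilistic models with random parameters compositionally. However, PyroSample interacts somewhat awkwardly with pyro.plate: a PyroSample attribute is sampled in whichever plate context it is first accessed, so the same attribute can come out global or batched depending on where it is read: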

```python
import pyro
import pyro.distributions
import pyro.nn

class Model(pyro.nn.PyroModule):

  @pyro.nn.PyroSample
  def loc(self):
    return pyro.distributions.Normal(0, 1)

  @pyro.nn.PyroSample
  def scale(self):
    return pyro.distributions.LogNormal(0, 1)

  def forward(self, x_obs):
    assert self.scale.shape == ()  # accessing self.scale triggers pyro.sample outside the plate
    with pyro.plate("data", x_obs.shape[0], dim=-1):
      assert self.loc.shape == (x_obs.shape[0],)  # accessing self.loc here triggers pyro.sample inside the plate
      return pyro.sample("x", pyro.distributions.Normal(self.loc, self.scale), obs=x_obs)
```

To ensure that loc and scale are sampled globally, they must be accessed outside the data plate, as scale is above. Because self.loc is first accessed inside the plate, a different loc is sampled for each datapoint. This behavior is semantically unambiguous, but it can cause confusion in more complex models and requires a lot of ugly boilerplate in the model to manually sample the random parameters of submodules in the correct plate context.

For example, in the code below the intuitive behavior for Model.linear is for linear.weight to be sampled outside the data plate. However, because self.linear is first invoked inside the plate, a separate random copy of linear.weight is sampled for each plate slice:

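```python
import torch
import pyro
import pyro.distributions as dist
import pyro.nn

class BayesianLinear(pyro.nn.PyroModule[torch.nn.Linear]):

  @pyro.nn.PyroSample
  def weight(self):
    # torch.nn.Linear stores its weight with shape (out_features, in_features)
    return dist.Normal(0., 1.).expand([self.out_features, self.in_features]).to_event(2)

class Model(pyro.nn.PyroModule):
  def __init__(self, num_inputs, num_outputs):
    super().__init__()
    self.linear = BayesianLinear(num_inputs, num_outputs)

  def forward(self, x):
    with pyro.plate("data", x.shape[-2], dim=-1):
      # weight is first sampled here, inside the plate, so it gets one copy per datapoint
      # (nn.Linear.forward itself expects a 2D weight, so this batched weight also breaks the forward pass)
      loc = self.linear(x)
      assert self.linear.weight.shape[-3] == x.shape[-2]
      return pyro.sample("y", dist.Normal(loc, 1.).to_event(1))
```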

However, it would not be correct to simply ignore all plates when executing PyroSample statements. In this example, we might want to infer self.linear.weight with a multi-sample ELBO estimator (e.g. pyro.infer.Trace_ELBO(num_particles=10, vectorize_particles=True)), which is implemented with an additional plate over particles that encloses the model and should not be ignored.

Proposed fix

It would be nice to have a feature that enables the intuitive behavior of the second example above without breaking backward compatibility with PyroSample's existing semantics or its correctness in the presence of enclosing plates, such as the one introduced by the multi-sample ELBO.

This could potentially be achieved with a new handler, PyroSamplePlateScope, such that PyroSample statements executed inside its context are modified only by plates entered outside it, while ordinary pyro.sample statements are unaffected and behave as usual. Under such a scope, the second example would behave as follows:

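```python
import torch
import pyro
import pyro.distributions as dist
import pyro.nn

# BayesianLinear is defined as in the previous example.
# PyroSamplePlateScope is the proposed handler; it does not exist in Pyro yet.

class Model(pyro.nn.PyroModule):
  def __init__(self, num_inputs, num_outputs):
    super().__init__()
    self.num_inputs = num_inputs
    self.num_outputs = num_outputs
    self.linear = BayesianLinear(num_inputs, num_outputs)

  @pyro.nn.PyroSample
  def scale(self):
    return pyro.distributions.LogNormal(0, 1).expand([self.num_outputs]).to_event(1)

  def forward(self, x):
    with pyro.plate("data", x.shape[-2], dim=-1):
      loc = self.linear(x)
      assert self.linear.weight.shape[-3] == 1  # sampled outside data plate
      assert self.scale.shape[-2] == 1  # sampled outside data plate
      y = pyro.sample("y", dist.Normal(loc, self.scale).to_event(1))
      assert y.shape[-2] == x.shape[-2]  # ordinary pyro.sample statement
      return y

# Hypothetical usage: the "data" plate is entered inside the scope,
# so it does not batch the PyroSample sites.
# with PyroSamplePlateScope():
#   y = Model(3, 2)(torch.randn(10, 3))
```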

Posted by eb8680 on Jul 17 '24 15:07