extending Predictive to support sampling by plate name

Open vitkl opened this issue 4 years ago • 5 comments

Following our earlier discussion with @fritzo and @adamgayoso, I'm creating this issue. It would be great if the Predictive class could support sampling all variables in a given plate, automatically pulling all site values in a given plate (plate_name), or all sites not in that plate, including both sampled and observed values:

from pyro import poutine

trace = poutine.trace(model).get_trace(*args, **kwargs)  # you'll need to fill in args,kwargs
# Keep every sample site (latent or observed) whose plate stack contains plate_name.
cell_samples = {
    name: site["value"]
    for name, site in trace.nodes.items()
    if site["type"] == "sample"
    if any(f.name == plate_name for f in site["cond_indep_stack"])
}

This functionality would simplify dealing with local variables within a minibatch plate, by making it easy to do separate predictive calls for the minibatch plate variables and global variables.
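
To make the "in that plate or not in that plate" split concrete, here is a minimal sketch that runs without Pyro. The `Frame` tuple, the `split_sites_by_plate` helper and the toy `nodes` dict are all hypothetical stand-ins that only mimic the parts of `trace.nodes` used in the snippet above; they are not part of the Pyro API.

```python
from collections import namedtuple

# Hypothetical stand-in for a plate frame; only the fields used here are modelled.
Frame = namedtuple("Frame", ["name", "dim"])


def split_sites_by_plate(nodes, plate_name):
    """Return (in_plate, not_in_plate) dicts of sample-site values."""
    in_plate, not_in_plate = {}, {}
    for name, site in nodes.items():
        if site["type"] != "sample":
            continue  # skip bookkeeping nodes such as inputs and return values
        frames = site.get("cond_indep_stack", ())
        inside = any(f.name == plate_name for f in frames)
        (in_plate if inside else not_in_plate)[name] = site["value"]
    return in_plate, not_in_plate


# Toy nodes shaped like trace.nodes; the values are placeholders.
nodes = {
    "_INPUT": {"type": "args", "value": None},
    "y_fg": {"type": "sample", "value": [[0.1, 0.2]], "cond_indep_stack": ()},
    "x_cf": {
        "type": "sample",
        "value": [[1.0], [2.0]],
        "cond_indep_stack": (Frame("cells", -2),),
    },
    "d_cg": {
        "type": "sample",
        "value": [[0.1, 0.2], [0.2, 0.4]],
        "cond_indep_stack": (Frame("cells", -2),),
    },
}

local_sites, global_sites = split_sites_by_plate(nodes, "cells")
```

With a real trace, the same function could presumably be called on `trace.nodes` directly, since it only relies on the `type`, `value` and `cond_indep_stack` keys.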

vitkl avatar Mar 16 '21 16:03 vitkl

@vitkl @jamestwebber can you elaborate on the task you're trying to accomplish? From @jamestwebber's comment

I find myself in a similar situation, trying to concatenate posterior samples

it sounds like Predictive is missing support for vectorization somewhere, but I suspect I'm misunderstanding.

fritzo avatar Mar 16 '21 18:03 fritzo

it sounds like Predictive is missing support for vectorization somewhere, but I suspect I'm misunderstanding.

It's very possible that I just don't understand the right pattern for accomplishing this. For me, the intended workflow is something like

  1. Learn a model of some large dataset using mini-batches for training
  2. Make a Predictive instance with pred = Predictive(model, guide=guide, num_samples=1000)
  3. Collect the posterior on my large dataset by iterating through the same data loader
posteriors = []
for x, y in data_loader:
    posteriors.append(pred(x, y))

aggregate_posterior_minibatches()

So that in the end I'd have a posterior over both the global variables (i.e. some background effect) and the minibatch-specific ones (for example, cell-specific loading factors).
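
One possible shape for the hypothetical `aggregate_posterior_minibatches` step, sketched with plain nested lists so it runs without torch. It assumes the names of the minibatch-specific sites are known in advance (`local_sites`) and that every value is indexed as `[sample][item]`; neither assumption comes from Pyro itself.

```python
def aggregate_posterior_minibatches(batches, local_sites):
    """Merge per-minibatch sample dicts into a single dict.

    batches: list of dicts mapping site name -> samples indexed [sample][item]
    local_sites: names of sites whose values differ per minibatch (assumed known)
    """
    merged = {}
    for name in batches[0]:
        if name in local_sites:
            num_samples = len(batches[0][name])
            # Join the item axis of every minibatch, one posterior draw at a time.
            merged[name] = [
                [item for batch in batches for item in batch[name][s]]
                for s in range(num_samples)
            ]
        else:
            # Global sites are redrawn on every call; keep the first set of draws.
            merged[name] = batches[0][name]
    return merged


# Two minibatches with 2 posterior draws each.
# "loading" is minibatch-specific, "background" is global.
batches = [
    {"loading": [[1, 2], [3, 4]], "background": [0.5, 0.6]},
    {"loading": [[5], [6]], "background": [0.7, 0.8]},
]
merged = aggregate_posterior_minibatches(batches, local_sites={"loading"})
```

Keeping only the first batch's global draws is a simplification: local draws from later batches were generated alongside different global draws, so the merged result is not a joint sample unless globals and locals are independent in the guide.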

jamestwebber avatar Mar 16 '21 19:03 jamestwebber

@jamestwebber that makes sense, thanks! I believe it should be easy to adapt the logic in Forecaster, which manually batches and torch.cats the batches back to a single tensor.

fritzo avatar Mar 16 '21 19:03 fritzo

Yes, exactly as @jamestwebber described: we need a way to sample both the global variables and the local variables (including amortised variables). The idea is that by doing predictive sampling by plate name you can isolate local variables from global variables and sample them via 2 separate calls. I think 2 separate calls are needed because, if you iterate over the data loader, global variables will be sampled in every iteration, whereas local variables will only be sampled once, in their respective minibatch. Separating the calls allows sampling the global variables once (where I count y_fg as global: d_cg = x_cf @ y_fg, and x_cf is minibatched in the c plate).

I see 2 separate problems:

  1. How to tell predictive to sample local variables from a batch plate
  2. How to concatenate the batches of local variables correctly (which requires recognising the plate dimension)
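
For the second problem, a rough sketch of concatenating along the plate dimension, again using nested lists instead of tensors. It assumes each plate frame exposes a negative `dim` counted from the right, so the index stays valid after a leading sample dimension is prepended; `Frame`, `depth` and `concat` are hypothetical helpers, and with real tensors `torch.cat(chunks, dim=plate_dim)` would take the place of `concat`.

```python
from collections import namedtuple

# Hypothetical stand-in for a plate frame with a right-aligned (negative) dim.
Frame = namedtuple("Frame", ["name", "dim"])


def depth(x):
    """Number of nested list levels, i.e. the number of dimensions."""
    d = 0
    while isinstance(x, list):
        x = x[0]
        d += 1
    return d


def concat(chunks, dim):
    """Concatenate nested lists along `dim`; negative dims count from the right."""
    axis = dim % depth(chunks[0])
    if axis == 0:
        return [row for chunk in chunks for row in chunk]
    # Recurse into the leading axis until the target axis becomes axis 0.
    return [
        concat([chunk[i] for chunk in chunks], axis - 1)
        for i in range(len(chunks[0]))
    ]


# x_cf samples indexed [sample][c][f]; the "cells" plate sits at dim -2.
frames = (Frame("cells", -2),)
plate_dim = next(f.dim for f in frames if f.name == "cells")

batch_1 = [[[1], [2]], [[3], [4]]]  # 2 draws, 2 cells, 1 factor
batch_2 = [[[5]], [[6]]]            # 2 draws, 1 cell, 1 factor
x_cf_all = concat([batch_1, batch_2], plate_dim)
```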

vitkl avatar Mar 17 '21 02:03 vitkl

I defined a method _posterior_samples_minibatch for doing this for scRNA-seq models in a recent PR to scVI: https://github.com/YosefLab/scvi-tools/blob/b020ff5fef9248df17f07d7ed1a6f794d8f7673e/scvi/model/base/_pyromixin.py#L262-L351 It uses poutine to generate samples one by one and is pretty fast (https://github.com/YosefLab/scvi-tools/blob/b020ff5fef9248df17f07d7ed1a6f794d8f7673e/scvi/model/base/_pyromixin.py#L133-L180). However, this approach is specific to scVI because the data loader is part of the method.

vitkl avatar May 17 '21 15:05 vitkl