pyro
extending Predictive to support sampling by plate name
Following our earlier discussion with @fritzo and @adamgayoso, I'm creating this issue. It would be great if the Predictive class could support sampling all variables in a given plate. It would automatically pull all site values in a given plate (plate_name), or all sites not in that plate, including both sampled and observed values:
```python
from pyro import poutine

# model, plate_name, args and kwargs are your own; you'll need to fill in args, kwargs
trace = poutine.trace(model).get_trace(*args, **kwargs)
cell_samples = {
    name: site["value"]
    for name, site in trace.nodes.items()
    if site["type"] == "sample"
    if any(f.name == plate_name for f in site["cond_indep_stack"])
}
```
This functionality would simplify handling local variables within a minibatch plate by making it easy to run separate predictive calls for the minibatch-plate variables and for the global variables.
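The comprehension above keeps only the in-plate sites. A minimal sketch that also returns the complement (the "all sites not in that plate" case) could partition the trace nodes in one pass. `split_sites_by_plate` is a hypothetical helper, not a Pyro API. It assumes only the trace structure the snippet already relies on: each node is a dict with `"type"` and `"value"`, and sample sites carry a `"cond_indep_stack"` of frames with a `.name`. To keep it runnable on its own, a small toy dict stands in for `trace.nodes`.

```python
from types import SimpleNamespace


def split_sites_by_plate(nodes, plate_name):
    """Hypothetical helper: partition sample sites into those inside
    `plate_name` and those outside it (observed sites are included too)."""
    in_plate, outside = {}, {}
    for name, site in nodes.items():
        if site["type"] != "sample":
            continue  # skip "_INPUT", "_RETURN", params and other non-sample nodes
        frames = site["cond_indep_stack"]
        target = in_plate if any(f.name == plate_name for f in frames) else outside
        target[name] = site["value"]
    return in_plate, outside


# Toy stand-in for trace.nodes; with Pyro you would pass
# poutine.trace(model).get_trace(*args, **kwargs).nodes instead.
cells = SimpleNamespace(name="cells", dim=-1)  # mimics a plate frame
nodes = {
    "_INPUT": {"type": "args"},
    "background": {"type": "sample", "cond_indep_stack": (), "value": 0.5},
    "loading": {"type": "sample", "cond_indep_stack": (cells,), "value": [1.0, 2.0]},
    "obs": {"type": "sample", "cond_indep_stack": (cells,), "value": [3.0, 4.0]},
}
local_sites, global_sites = split_sites_by_plate(nodes, "cells")
print(sorted(local_sites), sorted(global_sites))  # ['loading', 'obs'] ['background']
```

The two returned dicts map directly onto the two separate predictive calls described above: one for the minibatch plate and one for everything else.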
@vitkl @jamestwebber can you elaborate on the task you're trying to accomplish? From @jamestwebber's comment:
I find myself in a similar situation, trying to concatenate posterior samples
it sounds like Predictive is missing support for vectorization somewhere, but I suspect I'm misunderstanding.
It's very possible that I just don't understand the right pattern for accomplishing this. For me, the intended workflow is something like:
- Learn a model of some large dataset using mini-batches for training
- Make a Predictive instance with pred = Predictive(model, guide=guide, num_samples=1000)
- Collect the posterior on my large dataset by iterating through the same data loader:
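```python
# pred and data_loader are the Predictive instance and training data loader from above
posteriors = []
for x, y in data_loader:
    posteriors.append(pred(x, y))  # one dict of {site_name: samples} per minibatch
# hypothetical helper (not part of Pyro): merge the per-minibatch dicts
posterior = aggregate_posterior_minibatches(posteriors)
```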
In the end I'd have a posterior over both the global variables (e.g. some background effect) and the minibatch-specific ones (for example, cell-specific loading factors).
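One possible body for the `aggregate_posterior_minibatches` placeholder is sketched below. It is hypothetical and takes an extra `cat_dims` argument naming the local sites and the dim to concatenate each along. It assumes each minibatch result is a dict of tensors with the sample dim first, as Predictive returns by default. Sites not listed in `cat_dims` are treated as global and taken from the first minibatch only. Each Predictive call redraws the global sites, so this simply discards the other copies.

```python
import torch


def aggregate_posterior_minibatches(posteriors, cat_dims):
    """Hypothetical helper: merge per-minibatch Predictive outputs.

    posteriors: list of {site_name: tensor} dicts, one per minibatch.
    cat_dims: {site_name: dim} for local sites; these are concatenated along
        `dim`. All other sites are treated as global and taken from the
        first minibatch.
    """
    merged = {}
    for name in posteriors[0]:
        if name in cat_dims:
            merged[name] = torch.cat([p[name] for p in posteriors], dim=cat_dims[name])
        else:
            merged[name] = posteriors[0][name]
    return merged


# Toy example: 4 posterior samples, minibatches of 3 and 2 cells.
num_samples = 4
posteriors = [
    {"background": torch.randn(num_samples, 1), "loading": torch.randn(num_samples, 3)},
    {"background": torch.randn(num_samples, 1), "loading": torch.randn(num_samples, 2)},
]
merged = aggregate_posterior_minibatches(posteriors, {"loading": -1})
print(merged["loading"].shape)  # torch.Size([4, 5])
```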
@jamestwebber that makes sense, thanks! I believe it should be easy to adapt the logic in Forecaster, which manually splits the work into batches and concatenates the results back into a single tensor with torch.cat.
Yes, exactly as @jamestwebber described: we need a way to sample both the global variables and the local variables (including amortised variables). The idea is that predictive sampling by plate name lets you isolate local variables from global variables and sample them via two separate calls. I think two separate calls are needed because, if you iterate over the data loader, the global variables are resampled in every iteration, whereas each local variable is sampled only once, in its own minibatch. Separating the calls allows the global variables to be sampled only once (here I count y_fg as global: d_cg = x_cf @ y_fg, where x_cf is minibatched in the c plate).
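The y_fg example can be made concrete at the level of shapes. The sketch below uses random tensors in place of posterior draws, and all sizes are made up. Because y_fg is drawn once and shared, every minibatch of x_cf is paired with the same S draws. The per-minibatch results therefore concatenate along the cell dim into exactly the tensor the full-data computation would give.

```python
import torch

S, C, F, G, B = 10, 6, 3, 4, 2  # samples, cells, factors, genes, minibatch size
torch.manual_seed(0)
y_fg = torch.randn(S, F, G)  # global: drawn once and shared by every minibatch
x_cf = torch.randn(S, C, F)  # local: one row per cell, minibatched in the c plate

# Compute d_cg minibatch by minibatch, reusing the same S draws of y_fg.
d_batches = [x_cf[:, i:i + B] @ y_fg for i in range(0, C, B)]
d_cg = torch.cat(d_batches, dim=-2)  # concatenate along the cell dim

print(d_cg.shape)  # torch.Size([10, 6, 4])
```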
I see 2 separate problems:
- How to tell Predictive to sample local variables from a batch plate
- How to concatenate the batches of local variables correctly (which requires identifying the plate dimension)
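For the second problem, the concatenation dim could be read off the same trace. The sketch below rests on a few assumptions. A plate frame's `.dim` is a negative index into the site's batch shape. The site's value has shape `batch_shape + event_shape`, and `site["fn"].event_dim` gives the number of event dims. Shifting `frame.dim` left by `event_dim` then indexes the full value tensor. The index is negative, so it stays valid after Predictive prepends the sample dim. `plate_cat_dims` is a hypothetical helper, and toy dicts stand in for `trace.nodes`.

```python
from types import SimpleNamespace


def plate_cat_dims(nodes, plate_name):
    """Hypothetical helper: {site_name: dim} giving the negative tensor dim
    along which each site in `plate_name` should be concatenated."""
    dims = {}
    for name, site in nodes.items():
        if site["type"] != "sample":
            continue
        for frame in site["cond_indep_stack"]:
            if frame.name == plate_name:
                # frame.dim counts from the right of the batch shape; shift it
                # past the event dims to index into the full value tensor.
                dims[name] = frame.dim - site["fn"].event_dim
    return dims


# Toy stand-ins for plate frames and distributions (only .name/.dim/.event_dim used).
cells = SimpleNamespace(name="cells", dim=-1)
genes = SimpleNamespace(name="genes", dim=-2)
nodes = {
    "background": {"type": "sample", "cond_indep_stack": (), "fn": SimpleNamespace(event_dim=0)},
    "loading": {"type": "sample", "cond_indep_stack": (cells,), "fn": SimpleNamespace(event_dim=1)},
    "obs": {"type": "sample", "cond_indep_stack": (genes, cells), "fn": SimpleNamespace(event_dim=0)},
}
print(plate_cat_dims(nodes, "cells"))  # {'loading': -2, 'obs': -1}
```

The resulting dict has the same `{site_name: dim}` form that a torch.cat-based merge over minibatches would need.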
I defined a method _posterior_samples_minibatch for doing this for scRNA-seq models in a recent PR to scVI:
https://github.com/YosefLab/scvi-tools/blob/b020ff5fef9248df17f07d7ed1a6f794d8f7673e/scvi/model/base/_pyromixin.py#L262-L351
It uses poutine to generate samples one by one and is pretty fast (https://github.com/YosefLab/scvi-tools/blob/b020ff5fef9248df17f07d7ed1a6f794d8f7673e/scvi/model/base/_pyromixin.py#L133-L180). However, this approach is specific to scVI because the data loader is part of the method.