[FR] Streaming MCMC interface for big models
This issue proposes a streaming architecture for MCMC on models with large memory footprint.
The problem this addresses is that, in models with high-dimensional latents (say >1M latent variables), it becomes difficult to save a list of samples, especially on GPUs with limited memory. The proposed solution is to eagerly compute statistics on those samples, and discard them during inference.
@fehiepsi suggested creating a new MCMC class (say StreamingMCMC) with a similar interface to MCMC, still independent of the kernel (usable with either HMC or NUTS), but following an internal streaming architecture. Since large models like these usually run on GPU or are otherwise memory constrained, it is reasonable to avoid multiprocessing support in StreamingMCMC.
Along with the new StreamingMCMC class, I think there should be a set of helpers to compute statistics from sample streams in a streaming fashion, e.g. mean, variance, covariance, and r_hat.
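For intuition, the mean and variance of a sample stream can be maintained without storing any samples, using Welford-style incremental updates (see pyro.ops.welford). A minimal sketch, with illustrative names that are not an existing Pyro API:

```python
import torch


class WelfordMeanVariance:
    """Streaming mean/variance via Welford's algorithm (illustrative sketch).

    Each update costs O(size of sample) and no sample history is kept.
    """

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean

    def update(self, x: torch.Tensor) -> None:
        self.count += 1
        delta = x - self.mean
        self.mean = self.mean + delta / self.count
        self.m2 = self.m2 + delta * (x - self.mean)

    def get(self):
        # Unbiased variance; meaningful once at least two samples are seen.
        return {"mean": self.mean, "variance": self.m2 / max(self.count - 1, 1)}


# Usage: feed samples one at a time, discarding each immediately afterwards.
acc = WelfordMeanVariance()
for _ in range(1000):
    acc.update(torch.randn(10_000))  # stand-in for a high-dimensional sample
stats = acc.get()
```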
Tasks (to be split into multiple PRs)
@mtsokol
- [x] #2857 Create a `StreamingMCMC` class with interface identical to `MCMC` (except disallowing parallel chains).
- [x] #2857 Generalize unit tests of `MCMC` to parametrize over both `MCMC` and `StreamingMCMC`
- [ ] Add some tests ensuring `StreamingMCMC` and `MCMC` perform identical computations, up to numerical precision
- [ ] Create a tutorial using `StreamingMCMC` on a big model
@fritzo
- [x] #2856 Create streaming helpers for mean, variance, etc.
- [ ] Add `r_hat` to `pyro.ops.streaming`
- [ ] Add `n_eff = ess` to `pyro.ops.streaming`
Hi @fritzo!
I was searching for an issue to work on, and if this one's free I'd like to try solving it.
So, as I understand it, the main point here would be to implement StreamingMCMC, which doesn't contain a get_samples method and instead keeps incrementally updated statistics in its state (if all of them can be computed incrementally, can they?). Something like this:
```python
class StreamingMCMC:
    def __init__(...):
        self.incremental_mean = ...
        self.incremental_variance = ...
        # and the rest of the statistics that will be used by the 'summary' method

    def run(self, *args, **kwargs):
        ...
        num_samples = 0
        for x, chain_id in self.sampler.run(*args, **kwargs):
            num_samples += 1
            self.incremental_mean += (x - self.incremental_mean) / num_samples
            # ...and the rest of the statistics
            del x
        ...

    def diagnostics(self):
        ...

    def summary(self, prob=0.9):
        # just returns the computed incremental statistics
        ...
```
Also in test_mcmc_api.py additional tests should run both types of MCMC classes and compare final statistics.
As described, StreamingMCMC shouldn't support multiprocessing manually via _Worker because here CUDA, which is thought to be the main beneficiary of this new class, handles vectorization by itself. (is it correct?)
Follow up question: Should StreamingMCMC have num_chains argument and for num_chains>1 just compute them sequentially or omit this argument?
It's seemingly straightforward, but I've just started looking at the source code. Are there any pitfalls that I should bear in mind?
Hi @mtsokol, that sounds great, and I'm happy to provide any review and guidance.
Your sketch looks good. The only difference I'd suggest would be for us to think hard about making a fully extensible interface for computing streaming statistics, so that users can easily stream other custom things. I was thinking, with task 2 above, to create a new module, say pyro.ops.streaming, with a class hierarchy of basic streamable statistics:
```python
from abc import ABC, abstractmethod
from typing import Dict

import torch


class StreamingStatistic(ABC):
    """Base class for streamable statistics."""

    @abstractmethod
    def update(self, sample: Dict[str, torch.Tensor]) -> None:
        """Update state from a single sample."""
        raise NotImplementedError

    @abstractmethod
    def merge(self, other: "StreamingStatistic") -> "StreamingStatistic":
        """Combine two aggregate statistics, e.g. from different chains."""
        assert type(self) == type(other)
        raise NotImplementedError

    @abstractmethod
    def get(self) -> Dict[str, torch.Tensor]:
        """Return the aggregate statistic."""
        raise NotImplementedError
```
Together with a set of basic concrete statistics (see also pyro.ops.welford for an existing implementation, albeit with a non-general interface):
```python
class Count(StreamingStatistic): ...
class Mean(StreamingStatistic): ...
class MeanAndVariance(StreamingStatistic): ...
class MeanAndCovariance(StreamingStatistic): ...
class RHat(StreamingStatistic): ...
```
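For concreteness, `Mean` might be implemented against this interface roughly as follows (a sketch only; the eventual pyro.ops.streaming implementations may differ in detail):

```python
from collections import defaultdict
from typing import Dict

import torch


class Mean(StreamingStatistic):
    """Streaming per-site mean. Illustrative sketch only."""

    def __init__(self):
        self._count = 0
        self._sums = defaultdict(float)  # site name -> running sum

    def update(self, sample: Dict[str, torch.Tensor]) -> None:
        self._count += 1
        for name, value in sample.items():
            self._sums[name] = self._sums[name] + value.detach()

    def merge(self, other: "Mean") -> "Mean":
        assert type(self) == type(other)
        result = Mean()
        result._count = self._count + other._count
        for name in set(self._sums) | set(other._sums):
            result._sums[name] = self._sums.get(name, 0) + other._sums.get(name, 0)
        return result

    def get(self) -> Dict[str, torch.Tensor]:
        return {name: s / self._count for name, s in self._sums.items()}
```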
And maybe a restriction to a subset of names:
```python
from typing import Set


class SubsetStatistic(StreamingStatistic):
    """Restrict a base statistic to a subset of sample site names."""

    def __init__(self, names: Set[str], base_stat: StreamingStatistic):
        self.names = names
        self.base_stat = base_stat

    def update(self, sample):
        sample = {k: v for k, v in sample.items() if k in self.names}
        self.base_stat.update(sample)

    def get(self):
        return self.base_stat.get()
```
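Usage might then look like this (the site name "beta" is just an example):

```python
# Stream mean and variance for the single site "beta", ignoring other sites.
beta_stat = SubsetStatistic({"beta"}, MeanAndVariance())
```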
I think that might be enough of an interface, but we might want more details in the __init__ methods.
Then once we have basic statistics we can make your interface generic and extensible:
```python
class StreamingMCMC:
    def __init__(..., statistics=None):
        if statistics is None:
            statistics = [Count(), MeanAndVariance()]
        self._statistics = statistics

    def run(self, *args, **kwargs):
        ...
        for x, chain_id in self.sampler.run(*args, **kwargs):
            num_samples += 1
            for stat in self._statistics:
                stat.update(x)
            del x
        ...

    def diagnostics(self):
        ...

    def summary(self, prob=0.9):
        # just returns the computed incremental statistics
        ...
```
What I'd really like is to be able to define custom statistics for a particular problem, e.g. saving a list of norms:
```python
from collections import defaultdict

import torch


class ListOfNorms(StreamingStatistic):
    """Custom statistic: keep only the norm of each sample site."""

    def __init__(self):
        self._lists = defaultdict(list)

    def update(self, data):
        for k, v in data.items():
            self._lists[k].append(torch.linalg.norm(v.detach().reshape(-1)).item())

    def get(self):
        return dict(self._lists)


my_mcmc = StreamingMCMC(..., statistics=[MeanAndVariance(), ListOfNorms()])
```
WDYT?
Addressing your earlier questions:
> Also in test_mcmc_api.py additional tests should run both types of MCMC classes and compare final statistics.
Correct, most existing tests should be parametrized with

```python
@pytest.mark.parametrize("mcmc_cls", [MCMC, StreamingMCMC])
```
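For example, a parametrized test might look roughly like this (a sketch with a toy model; `StreamingMCMC` is assumed importable alongside `MCMC` once implemented):

```python
import pytest

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS
# StreamingMCMC is hypothetical here, assumed to live next to MCMC once implemented.


def model():
    pyro.sample("x", dist.Normal(0.0, 1.0))


@pytest.mark.parametrize("mcmc_cls", [MCMC, StreamingMCMC])
def test_classes_agree(mcmc_cls):
    mcmc = mcmc_cls(NUTS(model), num_samples=200, warmup_steps=100)
    mcmc.run()
    # Compare final statistics here, up to numerical precision; the exact
    # accessor on StreamingMCMC is still being designed in this thread.
    ...
```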
> As described, StreamingMCMC shouldn't support multiprocessing manually via _Worker because here CUDA, which is thought to be the main beneficiary of this new class, handles vectorization by itself. (is it correct?)
Almost. The main beneficiaries here are large models that push against memory limits and therefore necessitate streaming rather than saving all samples in memory. And if you're pushing against memory limits, you'll want to avoid parallelizing and instead compute chains sequentially (which can itself be seen as a streaming operation). In practice, yes, most models that hit memory limits are run on GPU, but multicore CPU models can also be very performant.
> Should StreamingMCMC have num_chains argument and for num_chains>1 just compute them sequentially or omit this argument?
StreamingMCMC should still support num_chains > 1 (which is valuable for determining convergence), but should compute them sequentially.
@mtsokol would you want to work on this in parallel? Maybe you could implement the StreamingMCMC class using hand-coded statistics, I could implement a basic pyro.ops.streaming module, and over the course of a few PRs we could meet in the middle?
@fritzo thanks for the guidance! Right now I'm looking at the current implementation and starting to work on this.
The StreamingStatistic abstraction sounds good to me. StreamingMCMC will only iterate and call methods on the passed objects implementing that interface.
Sure! I can start working on StreamingMCMC and already follow the StreamingStatistic notion. When your PR is ready I will adjust my implementation.
Should I introduce some AbstractMCMC interface that existing MCMC and StreamingMCMC will implement?
Feel free to implement an AbstractMCMC interface if you like. I defer to your design judgement here.
@fritzo After thinking about handling those streamed samples I wanted to ask a few more questions:
- So right now samples are `yield`ed by the sampler and each one is appended to the right chain list by `z_flat_acc[chain_id].append(x_cloned)`. Then we do reshaping to get rid of the last dimension and have dict entries instead in that place (based on the yielded structure). Then we perform an element-wise transform (with `self.transforms`; the transform operation is determined by the dict entry). The streaming-based approach would go as follows: again, each sample is `yield`ed by the sampler. The sample is used to construct a dict (based on the yielded structure). Then that single dict is transformed (with `self.transforms`) and the sample is fed to each statistic via `update(self, sample: Dict[str, torch.Tensor])`. (So each single sample will result in constructing a new dict, is that OK?) WDYT?
- Should `StreamingStatistic` update be `chain_id`-aware? Like `update(self, chain_id: int, sample: Dict[str, torch.Tensor])`, so that it can compute chain-related diagnostics and support the `group_by_chain` argument?
- Why do we need to clone: `x_cloned = x.clone()` when `num_chains > 1`?
Follow-up on the first question: if such a thing makes a performance difference (I'm just wondering - it might be irrelevant), maybe instead of streaming each sample to the statistics it could work in batches instead, e.g. introduce an additional argument batch_size=100 so that StreamingMCMC waits until it aggregates 100 samples, then constructs the dict, performs the transformations, and feeds the whole batch to the statistics. (But maybe constructing a dict for each sample and transforming each sample separately isn't really an overhead - with a ready implementation I can run memory and time measurements.) WDYT?
@mtsokol answering your latest questions:
- tl;dr keep it simple. I do not foresee a performance hit here: it is cheap to create dicts, and `StreamingMCMC` will typically be used with large memory-bound models with huge tensors, where the Python overhead is negligible. For this same reason I think we should avoid batching, since that increases memory overhead. (In fact I suspect the bottleneck will be in pyro.ops.streaming, where we may need to refactor to perform tensor operations in-place.)
- Yes, I believe we will want to compute both per-chain and total-aggregated statistics. I have added a `.merge()` operation in #2856 to make this easy for you. The main motivation is to compute cross-chain statistics like r_hat.
- It looks like the cloning is explained earlier in the file. I would recommend keeping that logic.
  https://github.com/pyro-ppl/pyro/blob/4a61ef2f9050ef81d1b0aa148d14ecc876f24a51/pyro/infer/mcmc/api.py#L389-L392
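For illustration, sequentially computed chains then compose naturally with `.merge()` (a sketch against the interface above; `run_chain` and `num_chains` are hypothetical placeholders):

```python
# One statistic instance per chain, updated sequentially.
per_chain = [MeanAndVariance() for _ in range(num_chains)]
for chain_id in range(num_chains):
    for sample in run_chain(chain_id):  # hypothetical per-chain sample stream
        per_chain[chain_id].update(sample)

# Per-chain results support group_by_chain-style diagnostics...
chain_stats = [stat.get() for stat in per_chain]

# ...while merging yields total-aggregated statistics across chains,
# the motivation for .merge() in #2856 (e.g. cross-chain r_hat).
total = per_chain[0]
for stat in per_chain[1:]:
    total = total.merge(stat)
total_stats = total.get()
```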
Hi @fritzo!
I was wondering what I can try to do next.
As "Add r_hat to pyro.ops.streaming" is completed, I tried n_eff = ess for streaming, but after a short inspection of the current implementation it looks undoable to me (as it requires e.g. computing those autocorrelation lags over the sample history).
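For reference, the usual effective-sample-size estimate has the form

```latex
n_\text{eff} = \frac{N}{1 + 2 \sum_{t=1}^{\infty} \rho_t}
```

where $\rho_t$ is the lag-$t$ autocorrelation, so it needs sample values at many lags rather than just running aggregates.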
Apart from that I can definitely try:
> Create a tutorial using StreamingMCMC on a big model
Could you suggest to me a problem with a model that would be suitable for that?
Also, I could combine the new tutorial with your suggestion from the last bullet point in https://github.com/pyro-ppl/pyro/issues/2803#issuecomment-836644916 (showing how Predictive can be interchanged with poutine methods).
WDYT?
That would be a documentation task, and I was also looking for an implementation one. Do you have something I could try?