pyro
Implementation of quantiles for messenger guides [WIP]
This PR implements quantiles for messenger guides, closely following how this is done in AutoNormal. CI fails with errors I don't know how to address (https://github.com/pyro-ppl/pyro/runs/4490764371?check_suite_focus=true#step:5:22410): the messenger guide tries to expand the returned quantiles in new.v = self.v.expand(batch_shape + self.event_shape). One possible workaround is to loop over the requested quantiles and compute them one at a time:
import torch

def quantiles(self, quantiles, *args, **kwargs):
    # While this flag is set, the guide returns quantile values instead of samples.
    self._computing_quantiles = True
    try:
        collected = {}
        for q in quantiles:
            # Request one quantile at a time so each returned value has the
            # site's own shape rather than an extra quantile dimension.
            self._quantile_values = q
            for name, value in self(*args, **kwargs).items():
                collected.setdefault(name, []).append(value)
        # torch.stack (not torch.tensor) combines a list of tensors, adding a
        # leading dimension indexed by quantile.
        return {name: torch.stack(values) for name, values in collected.items()}
    finally:
        self._computing_quantiles = False
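The per-quantile loop can be exercised without Pyro using a toy stand-in. ToyQuantileGuide and expands_to below are hypothetical helpers, not Pyro API, and the single site "x" with a Normal posterior and batch_shape (4,) is an assumption. expands_to mimics the expand call from the error message with a plain Tensor.expand, to show why computing all quantiles in one call (broadcasting a (num_quantiles, 1) tensor, roughly as AutoNormal does) trips that call while the one-at-a-time loop does not. This is a sketch of the shape issue, not a reproduction of the linked CI failure.

```python
import torch
from torch.distributions import Normal


def expands_to(value, batch_shape, event_shape=torch.Size()):
    # Hypothetical helper: mimics `self.v.expand(batch_shape + self.event_shape)`
    # and reports whether that call would succeed instead of raising.
    try:
        value.expand(torch.Size(batch_shape) + event_shape)
        return True
    except RuntimeError:
        return False


class ToyQuantileGuide:
    """Hypothetical stand-in for a messenger guide with one Normal site "x"."""

    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale
        self._computing_quantiles = False
        self._quantile_values = None

    def __call__(self, *args, **kwargs):
        if not self._computing_quantiles:
            raise NotImplementedError("this toy guide only supports quantile mode")
        # Return the requested quantile of the site's posterior instead of a sample.
        q = torch.as_tensor(self._quantile_values, dtype=self.loc.dtype)
        return {"x": Normal(self.loc, self.scale).icdf(q)}

    def quantiles(self, quantiles, *args, **kwargs):
        # Loop over quantiles one at a time, as in the snippet above, then stack.
        self._computing_quantiles = True
        try:
            collected = {}
            for q in quantiles:
                self._quantile_values = q
                for name, value in self(*args, **kwargs).items():
                    collected.setdefault(name, []).append(value)
            return {name: torch.stack(vs) for name, vs in collected.items()}
        finally:
            self._computing_quantiles = False


loc, scale = torch.zeros(4), torch.ones(4)  # a site with batch_shape (4,)

# One quantile per call: each value has the site's shape, so it can be expanded.
result = ToyQuantileGuide(loc, scale).quantiles([0.1, 0.5, 0.9])
print(result["x"].shape)  # torch.Size([3, 4])

# All quantiles in one call (a (3, 1) tensor broadcast against loc) produce a
# (3, 4) value that cannot be expanded to batch_shape (4,).
all_at_once = Normal(loc, scale).icdf(torch.tensor([[0.1], [0.5], [0.9]]))
print(expands_to(all_at_once, (4,)))     # False
print(expands_to(result["x"][1], (4,)))  # True
```

The stacked result keeps the quantile dimension first, so result["x"][i] corresponds to the i-th requested quantile.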
Please let me know what you think @fritzo
Is it possible to restart the checks? Maybe the issue has been fixed by recent changes to Pyro or PyTorch?
Not sure how to restart checks in GitHub Actions, but you could merge the Pyro dev branch and push; that should rerun CI.