Implement SNPE-B
SNPE-B is currently not implemented.
When implementing it take care with the following:
- we always evaluate the importance weights on the last posterior, no matter which round the samples actually come from.
- to be checked: does the posterior in the importance weights really not have gradients that are updated anymore?
@michaeldeistler Did this get fixed?
No, not fixed yet. SNPE-B raises a `NotImplementedError`.
Hi, is SNPE-B available in sbi? I just searched the documentation and found no match. Cheers.
No, it's not available yet.
More context: At the moment, we only implement multi-round NPE via
- SNPE-A with post-hoc correction,
- and SNPE-C based on learning the proposal posterior.
SNPE-B enables multi-round NPE using an importance-weighted loss. Although training with importance weights can be unstable, SNPE-B can still be useful for some applications and would be "nice to have".
Link to paper: https://proceedings.neurips.cc/paper/2017/hash/addfa9b7e234254d26e9c7f2af1005cb-Abstract.html
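For reference, the importance-weighted loss from the paper can be sketched as follows. This is a minimal sketch with hypothetical names: `log_prob_posterior` stands for the density estimator's log-probability of $\theta$ given $x$, and the other arguments are log-densities evaluated on the same batch of $\theta$'s.

```python
import torch

def snpe_b_loss(log_prob_posterior, log_prob_prior, log_prob_proposal):
    """Sketch of an importance-weighted negative log-likelihood (SNPE-B style).

    log_prob_posterior: log q(theta | x) under the density estimator.
    log_prob_prior:     log p(theta) under the prior.
    log_prob_proposal:  log p~(theta) under the proposal prior.
    """
    # w = p(theta) / p~(theta), computed in log-space for numerical stability.
    importance_weights = torch.exp(log_prob_prior - log_prob_proposal)
    return -(importance_weights * log_prob_posterior).mean()
```

When the proposal equals the prior (first round), all weights are $1$ and this reduces to the plain negative log-likelihood used by single-round NPE.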
rough outline for a PR:
- extend pseudo code in https://github.com/sbi-dev/sbi/blob/main/sbi/inference/snpe/snpe_b.py
- add tests in https://github.com/sbi-dev/sbi/blob/3aeb7756783867c30942802118b449b1c5739824/tests/linearGaussian_snpe_test.py#L44
Hello, I'm trying to implement SNPE-B with @touronc and @plcrodrigues as suggested in Issue #199. According to the paper Flexible statistical inference for mechanistic models of neural dynamics, Lueckmann, Gonçalves et al., NeurIPS 2017, during the training phase, the loss should be weighted by $\frac{p(\theta)}{\tilde{p}(\theta)}$ for $\theta$ sampled from the proposal prior.
But in the sbi code, when we consider multiple rounds (sequential NPE), the method `self.get_dataloaders` of the class `NeuralInference` creates shuffled batches of $\theta$'s coming from the prior (first round) or from both the prior and the proposal (subsequent rounds).
In some cases, the prior has heavier tails than the proposal: for some $\theta$'s coming from the prior, $\tilde{p}(\theta)$ is almost $0$ whereas $p(\theta)$ is strictly positive. The importance weight $\frac{p(\theta)}{\tilde{p}(\theta)}$ is then not well defined and produces $\infty$ values.
To illustrate this problem, we use the package sbibm with the task `gaussian_mixture`: in this task, the prior is uniform between $-10$ and $10$, and $\theta$ has $2$ dimensions.
Here is a case where the problem may happen (considering the second dimension only): if $\theta$ is sampled from the prior between $5$ and $10$, it has $\tilde{p}(\theta) \approx 0$ and the ratio is not well defined.
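The overflow can be reproduced in isolation with stand-in distributions (a uniform prior and a deliberately narrow Gaussian playing the role of the proposal; this is not the actual sbibm task, just an illustration of the mechanism):

```python
import torch
from torch.distributions import Normal, Uniform

prior = Uniform(-10.0, 10.0)   # heavy-tailed relative to the proposal
proposal = Normal(0.0, 0.5)    # narrow proposal concentrated near 0

theta = torch.tensor(9.5)      # a prior sample deep in the proposal's tail
# The log-ratio is finite (~ +177.7), but exponentiating it
# overflows float32, so the weight becomes inf.
log_w = prior.log_prob(theta) - proposal.log_prob(theta)
w = torch.exp(log_w)
```

This is why evaluating $\frac{p(\theta)}{\tilde{p}(\theta)}$ directly on prior samples can break training: the log-ratio itself is fine, but the exponentiated weight is not.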
To avoid this problem, we propose the following solution: thanks to the `masks_batch` contained in the batches of the `train_loader` and `val_loader`, we can tell whether each $\theta$ of the batch comes from the prior or not. For each $\theta$ in the batch coming from the proposal, we compute the importance weight as suggested in the paper; for each $\theta$ coming from the prior, we set the importance weight to $1$.
Here is a piece of the code:

```python
# Initialize log-probs to zero so prior samples end up with weight exp(0) = 1.
log_prob_prior = torch.zeros(theta.size(0))
log_prob_proposal = torch.zeros(theta.size(0))

# Entries with mask == False are proposal samples.
is_proposal = torch.logical_not(masks.squeeze())

# Evaluate the prior on the proposal samples.
log_prob_prior[is_proposal] = self._prior.log_prob(theta[is_proposal, :])
utils.assert_all_finite(log_prob_prior, "prior eval.")

# Evaluate the proposal on the proposal samples.
log_prob_proposal[is_proposal] = proposal.log_prob(theta[is_proposal, :])
utils.assert_all_finite(log_prob_proposal, "proposal posterior eval")

# Compute the importance weights (equal to 1 for prior samples).
importance_weights = torch.exp(log_prob_prior - log_prob_proposal)
```
This method seems more stable on the other examples from sbibm that we tried.
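A self-contained toy version of this masking scheme, with stand-in distributions and hypothetical shapes (`masks` marks prior samples as `True`, following sbi's `prior_masks` convention):

```python
import torch
from torch.distributions import Normal, Uniform

# Stand-in prior and proposal (not the actual sbi objects).
prior = Uniform(-10.0, 10.0)
proposal = Normal(0.0, 1.0)

# One prior sample (8.0) and one proposal sample (0.5).
theta = torch.tensor([[8.0], [0.5]])
masks = torch.tensor([[True], [False]])  # True = sample came from the prior

log_w = torch.zeros(theta.size(0))
is_proposal = torch.logical_not(masks.squeeze())
# Only proposal samples get a nonzero log-ratio; prior samples keep log_w = 0.
log_w[is_proposal] = (prior.log_prob(theta[is_proposal, 0])
                      - proposal.log_prob(theta[is_proposal, 0]))
importance_weights = torch.exp(log_w)  # prior sample gets weight exp(0) = 1
```

Note that the prior sample at $\theta = 8$ would have produced a huge weight under the naive ratio; here it simply receives weight $1$.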
As you know, there still remains the potential problem of $\theta$ samples from the tail of the proposal, in which case the ratio may also be unstable ($\tilde{p}(\theta) \approx 0$).
In our experiments with `gaussian_mixture`, we encountered this case: some $\theta$'s lie in the tail of the proposal. However, for the moment, we still obtain finite values for the ratio.
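One standard importance-sampling mitigation for this remaining tail problem (not from the SNPE-B paper, just a generic trick, and it does bias the estimator) is to clamp the log-ratio before exponentiating:

```python
import torch

def stabilized_weights(log_prob_prior, log_prob_proposal, max_log_ratio=10.0):
    """Importance weights with a capped log-ratio (generic stabilization sketch).

    Clamping the log-ratio means exp() cannot overflow, at the cost of
    biasing the weights for samples deep in the proposal's tail.
    """
    log_w = torch.clamp(log_prob_prior - log_prob_proposal, max=max_log_ratio)
    return torch.exp(log_w)
```

The cap `max_log_ratio` is a hypothetical tuning knob; whether such truncation is acceptable for SNPE-B training would need to be checked empirically.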
Question: Do you know if anyone else is working on SNPE-B as well? Do you think the solution proposed above could be interesting?
Thanks a lot for looking into this and adding the detailed report 👍
I am tagging @jan-matthis because he probably encountered similar implementation issues in delfi back then.
See also https://github.com/mackelab/delfi/blob/main/delfi/inference/SNPEB.py