
Implement SNPE-B

Open michaeldeistler opened this issue 4 years ago • 7 comments

SNPE-B is currently not implemented.

When implementing it, take care with the following:

  • we always evaluate the importance weights on the last posterior, no matter which round the samples actually come from.
  • to be checked: does the posterior used in the importance weights really no longer have gradients that get updated?
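One possible way to handle both points is to evaluate each round's proposal density once, at the time that round's parameters are drawn, and store it next to $\theta$. The sketch below assumes a hypothetical `ProposalLogProbStore` helper (not part of sbi) and uses a `torch.distributions.MultivariateNormal` with a trainable mean as a stand-in for a previous-round posterior. Because the log-probs are computed under `torch.no_grad()`, they carry no gradients, and later training steps cannot change them.

```python
import torch
from torch.distributions import MultivariateNormal


class ProposalLogProbStore:
    """Hypothetical helper: keeps theta together with log p~_r(theta) of the round it came from."""

    def __init__(self):
        self.theta = []
        self.log_proposal = []

    def append_round(self, theta, proposal):
        # Evaluate the proposal of *this* round once, without building an autograd graph,
        # so later updates to the posterior network neither change nor backpropagate into it.
        with torch.no_grad():
            self.log_proposal.append(proposal.log_prob(theta))
        self.theta.append(theta)

    def get(self):
        return torch.cat(self.theta), torch.cat(self.log_proposal)


# Stand-in for a previous-round posterior whose parameters are still trainable.
loc = torch.zeros(2, requires_grad=True)
proposal = MultivariateNormal(loc, covariance_matrix=torch.eye(2))

store = ProposalLogProbStore()
theta_r = proposal.sample((5,))  # .sample() does not track gradients
store.append_round(theta_r, proposal)

# Simulate a later training step that moves the posterior.
with torch.no_grad():
    loc.add_(1.0)
updated_posterior = MultivariateNormal(loc, covariance_matrix=torch.eye(2))

theta_all, log_proposal_all = store.get()
```

The stored values still describe the proposal that actually generated each sample, whereas re-evaluating `updated_posterior` would give the "last posterior" weights described in the first bullet.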

michaeldeistler, Jun 02 '20

@michaeldeistler Did this get fixed?

jan-matthis, Jun 18 '20

No, not fixed yet. SNPE-B raises a NotImplementedError.

michaeldeistler, Jun 18 '20

Hi, is SNPE-B available in sbi? I searched the documentation and found no match. Cheers.

Ziaeemehr, Sep 30 '21

No, it's not available yet.

michaeldeistler, Sep 30 '21

More context: At the moment, we only implement multi-round NPE via

  • SNPE-A with post-hoc correction,
  • and SNPE-C based on learning the proposal posterior.

SNPE-B enables multi-round NPE using an importance-weighted loss. Although training with importance weights can be unstable, SNPE-B can still be useful for some applications and would be "nice to have".

Link to paper: https://proceedings.neurips.cc/paper/2017/hash/addfa9b7e234254d26e9c7f2af1005cb-Abstract.html

Rough outline for a PR:

  • extend the pseudocode in https://github.com/sbi-dev/sbi/blob/main/sbi/inference/snpe/snpe_b.py
  • add tests in https://github.com/sbi-dev/sbi/blob/3aeb7756783867c30942802118b449b1c5739824/tests/linearGaussian_snpe_test.py#L44

janfb, Feb 27 '24

Hello, I'm trying to implement SNPE-B with @touronc and @plcrodrigues as suggested in Issue #199. According to the paper Flexible statistical inference for mechanistic models of neural dynamics, Lueckmann, Gonçalves et al., NeurIPS 2017, during the training phase, the loss should be weighted by $\frac{p(\theta)}{\tilde{p}(\theta)}$ for $\theta$ sampled from the proposal prior.
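For reference, here is the importance-weighted objective that this weighting implies, written with assumed notation: $q_\phi(\theta \mid x)$ is the density estimator, $N$ the number of training pairs, and any additional regularization terms that the paper's full method may include are omitted:

$$\mathcal{L}(\phi) = -\frac{1}{N} \sum_{n=1}^{N} \frac{p(\theta_n)}{\tilde{p}(\theta_n)} \log q_\phi(\theta_n \mid x_n), \qquad \theta_n \sim \tilde{p}(\theta), \; x_n \sim p(x \mid \theta_n).$$

The weights turn an average over the proposal into an average over the prior, because

$$\mathbb{E}_{\tilde{p}(\theta)}\!\left[\frac{p(\theta)}{\tilde{p}(\theta)} f(\theta)\right] = \mathbb{E}_{p(\theta)}\!\left[f(\theta)\right] \quad \text{provided } \tilde{p}(\theta) > 0 \text{ wherever } p(\theta) f(\theta) \neq 0,$$

with $f(\theta) = \mathbb{E}_{p(x \mid \theta)}[\log q_\phi(\theta \mid x)]$. That support condition is what fails when the prior has heavier tails than the proposal.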

But in the sbi code, when we consider multiple rounds (sequential NPE), the method self.get_dataloaders of the class NeuralInference creates shuffled batches of $\theta$'s coming from the prior (first round) or from both the prior and the proposal (later rounds). In some cases, the prior has heavier tails than the proposal, so for some $\theta$'s coming from the prior, $\tilde{p}(\theta)$ is almost $0$ whereas $p(\theta)$ is strictly positive. The importance weight $\frac{p(\theta)}{\tilde{p}(\theta)}$ is then not well defined and evaluates to $inf$.

To illustrate this problem, we use the package sbibm with the task gaussian_mixture: in this task, the prior is uniform between $-10$ and $10$, and $\theta$ has $2$ dimensions. Here is a case where the problem may happen (we show the second dimension only): if $\theta$ is sampled from the prior between $5$ and $10$, it will have $\tilde{p}(\theta) \approx 0$ and the ratio won't be well defined.

[Screenshot: second dimension of $\theta$ for the gaussian_mixture example]
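The overflow can be reproduced with `torch.distributions` alone. This is a minimal sketch: the prior matches the gaussian_mixture prior described above (uniform on $[-10, 10]^2$), but the proposal is a hypothetical narrow Gaussian (mean $0$, standard deviation $0.5$) chosen for illustration, not the posterior from the actual run.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Uniform

# Prior: uniform on [-10, 10]^2, as in the gaussian_mixture task.
prior = Independent(Uniform(torch.full((2,), -10.0), torch.full((2,), 10.0)), 1)
# Hypothetical narrow proposal standing in for a later-round posterior.
proposal = MultivariateNormal(torch.zeros(2), covariance_matrix=0.25 * torch.eye(2))

# One theta near the proposal mass, one prior sample whose second dimension lies in (5, 10).
theta = torch.tensor([[0.0, 0.5], [0.0, 8.0]])

log_w = prior.log_prob(theta) - proposal.log_prob(theta)  # finite log-weights
weights = torch.exp(log_w)  # float32 exp overflows to inf for arguments above ~88.7
```

The log-weight of the second sample is finite (about $122$); only exponentiating it in float32 produces `inf`.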

To avoid this problem, we propose the following solution: using the masks_batch contained in the batches of train_loader and val_loader, we can tell whether each $\theta$ in the batch comes from the prior or not. For $\theta$ coming from the proposal, we compute the importance weight as suggested in the paper, and for $\theta$ coming from the prior, we set the importance weight to $1$.

Here is the relevant piece of the code:

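        # masks is True for samples simulated from the prior (first round);
        # only samples drawn from the proposal get a non-trivial weight.
        # squeeze(-1) keeps a 1-D mask even when the batch has a single element.
        from_proposal = torch.logical_not(masks.squeeze(-1))
        log_prob_prior = torch.zeros(theta.size(0), device=theta.device)
        log_prob_proposal = torch.zeros(theta.size(0), device=theta.device)

        # The weights are constants w.r.t. the network parameters: do not track gradients.
        with torch.no_grad():
            # Evaluate prior on the proposal samples.
            log_prob_prior[from_proposal] = self._prior.log_prob(theta[from_proposal, :])
            utils.assert_all_finite(log_prob_prior, "prior eval.")

            # Evaluate proposal on the same samples.
            log_prob_proposal[from_proposal] = proposal.log_prob(theta[from_proposal, :])
            utils.assert_all_finite(log_prob_proposal, "proposal posterior eval")

        # Compute the importance weights; prior samples keep log-weight 0, i.e. weight 1.
        importance_weights = torch.exp(log_prob_prior - log_prob_proposal)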

This approach also seems more stable on the other sbibm examples we tried.
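For experimenting outside of sbi, the same masking idea can be written as a self-contained sketch. `masked_importance_weights` and `weighted_nll` are hypothetical helper names, the mask is assumed to be a 1-D boolean tensor (True for prior samples) rather than sbi's `(batch, 1)` masks, and `log_q` is a placeholder for the density estimator's log-probabilities.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Uniform


def masked_importance_weights(theta, from_prior, prior, proposal):
    """p(theta) / p~(theta) for proposal samples, 1 for prior samples (hypothetical helper)."""
    log_w = torch.zeros(theta.shape[0], device=theta.device)
    from_proposal = ~from_prior
    with torch.no_grad():  # weights are treated as constants during training
        log_w[from_proposal] = prior.log_prob(theta[from_proposal]) - proposal.log_prob(
            theta[from_proposal]
        )
    return torch.exp(log_w)


def weighted_nll(log_q, weights):
    """Importance-weighted negative log-likelihood (hypothetical helper)."""
    return -(weights * log_q).mean()


prior = Independent(Uniform(torch.full((2,), -10.0), torch.full((2,), 10.0)), 1)
proposal = MultivariateNormal(torch.zeros(2), covariance_matrix=0.25 * torch.eye(2))

# The first sample comes from the prior's tail, the second from the proposal.
theta = torch.tensor([[0.0, 8.0], [0.0, 0.5]])
from_prior = torch.tensor([True, False])

weights = masked_importance_weights(theta, from_prior, prior, proposal)
log_q = torch.tensor([-3.0, -1.0], requires_grad=True)  # placeholder for log q_phi(theta | x)
loss = weighted_nll(log_q, weights)
loss.backward()
```

The tail sample that overflowed before now gets weight $1$. The trade-off is that prior samples are no longer reweighted at all, so the loss mixes an unweighted prior term with an importance-weighted proposal term.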

As you know, there remains the potential problem of $\theta$ samples lying in the tail of the proposal, where the ratio may also be unstable ($\tilde{p}(\theta) \approx 0$). In our experiments with gaussian_mixture, we encountered this case: some $\theta$ lie in the tail of the proposal, as shown below. However, so far the ratio has remained finite.

[Screenshots: $\theta$ samples lying in the tail of the proposal (gaussian_mixture)]
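A common way to monitor this kind of instability (a standard importance-sampling diagnostic, not something proposed in this thread) is the effective sample size $(\sum_n w_n)^2 / \sum_n w_n^2$ of the weights in a batch. Values far below the batch size mean that a few tail samples dominate the loss. The sketch computes it from log-weights, so it stays finite even when individual weights would overflow. The log-weight $122$ below is an arbitrary large value for illustration, not a measured one.

```python
import torch


def effective_sample_size(log_w):
    """Kish effective sample size (sum w)^2 / sum w^2, computed stably from log-weights."""
    return torch.exp(2 * torch.logsumexp(log_w, 0) - torch.logsumexp(2 * log_w, 0))


equal = torch.zeros(100)  # 100 samples with identical weights
dominated = torch.cat([torch.tensor([122.0]), torch.zeros(99)])  # one tail sample dominates

ess_equal = effective_sample_size(equal)
ess_dominated = effective_sample_size(dominated)
```

With identical weights the effective sample size equals the batch size. With one dominating weight it collapses to about $1$, even though `torch.exp(torch.tensor(122.0))` itself would be `inf`.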

Question: Do you know if anyone else is working on SNPE-B as well? Do you think the solution proposed above could be interesting?

etouron1, Apr 03 '24

Thanks a lot for looking into this and adding the detailed report 👍

I am tagging @jan-matthis because he probably encountered similar implementation issues in delfi back then.

See also https://github.com/mackelab/delfi/blob/main/delfi/inference/SNPEB.py

janfb, Apr 03 '24