tests: switch to a bi-modal posterior for testing
We are testing our inference methods using a linear Gaussian simulator with a Gaussian target posterior.
Should we switch to a bimodal posterior instead, to detect possible shortcomings of inference methods, e.g., the MAF architecture collapsing to a Gaussian fit when used with 1D input?
- Pro: more testing power.
- Con: slower tests, since we need more simulations (fast) and longer training (slow).
Task for this issue:
- come up with a tractable and easy-to-understand bimodal posterior (a mixture of Gaussians?)
- check whether this results in slower tests by comparing it to the tests in
linearGaussian_snpe_test.py
Hello, I am currently working on issue #1430 in the context of a hackathon. With @plcrodrigues, we propose the following setup to obtain a bimodal posterior:
- dim = 2
- prior $p(\theta) = N((1,1), Id)$
- likelihood $p(x|\theta) = 0.5 N(\theta, \sigma^2Id) + 0.5 N(-\theta, \sigma^2Id)$
- posterior $p(\theta|x) = 0.5 N(\frac{x}{\sigma^2+1}, \frac{\sigma^2}{\sigma^2+1} Id) + 0.5 N(\frac{-x}{\sigma^2+1}, \frac{\sigma^2}{\sigma^2+1} Id)$
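As a sanity check, the setup above can be sketched in a few lines of numpy: a simulator drawing from the stated mixture likelihood, and a sampler for the stated mixture posterior. This is a minimal sketch, not the actual test code; the function names and the choice of `sigma` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, sigma = 2, 0.5  # sigma is an illustrative choice

def simulator(theta):
    """Bimodal likelihood: x ~ 0.5 N(theta, sigma^2 I) + 0.5 N(-theta, sigma^2 I)."""
    sign = rng.choice([1.0, -1.0])
    return sign * theta + sigma * rng.standard_normal(dim)

def sample_posterior(x, n):
    """Draw n samples from the stated mixture posterior p(theta | x)."""
    s2 = sigma**2
    mean = x / (s2 + 1.0)          # mode locations +/- x / (sigma^2 + 1)
    std = np.sqrt(s2 / (s2 + 1.0))  # per-component std, sqrt(sigma^2 / (sigma^2 + 1))
    signs = rng.choice([1.0, -1.0], size=(n, 1))
    return signs * mean + std * rng.standard_normal((n, dim))

theta = 1.0 + rng.standard_normal(dim)  # draw from the prior N((1,1), I)
x_o = simulator(theta)
samples = sample_posterior(x_o, 1000)   # reference samples for the test
```

These reference samples can then be compared against density-estimator samples with C2ST, as in the existing linear-Gaussian tests.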
For the moment, I have tested this example with single-round NPE: the C2ST is not close to 0.5, but the density estimator manages to target both modes of the posterior.
Does this framework seem relevant to you? Thanks!
As an example, when I run the test, the density estimator finds the two modes (blue scatter points in the figure), but the C2ST is 0.67.
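For context, C2ST trains a classifier to distinguish the two sample sets and reports its accuracy: 0.5 means the sets are indistinguishable, values near 1 mean they are clearly different. A self-contained numpy sketch, using a leave-one-out k-NN classifier as an illustrative stand-in for the classifier sbi actually uses:

```python
import numpy as np

def c2st_knn(X, Y, k=5):
    """Classifier two-sample test via leave-one-out k-NN accuracy.

    Returns ~0.5 when X and Y come from the same distribution,
    and approaches 1.0 when they are easy to tell apart.
    """
    data = np.concatenate([X, Y])
    labels = np.concatenate([np.zeros(len(X)), np.ones(len(Y))])
    n = len(data)
    correct = 0
    for i in range(n):
        dists = np.linalg.norm(data - data[i], axis=1)
        dists[i] = np.inf  # leave-one-out: never use the point itself
        neighbors = np.argsort(dists)[:k]
        pred = labels[neighbors].mean() > 0.5
        correct += pred == labels[i]
    return correct / n

rng = np.random.default_rng(0)
# Same distribution -> accuracy near 0.5; shifted -> near 1.0.
acc_same = c2st_knn(rng.standard_normal((150, 2)), rng.standard_normal((150, 2)))
acc_diff = c2st_knn(rng.standard_normal((150, 2)), 5.0 + rng.standard_normal((150, 2)))
```

This also illustrates why a few outliers inflate the score: any region covered by only one of the two sample sets is trivially classifiable.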
Thanks for tackling this @etouron1
the simulator and resulting bi-modal posterior you suggested make sense 👍
I assume the test is run with vanilla NPE? Yes, it looks reasonable except for the bunch of samples placed in the upper left region of the space, which causes C2ST to fail, as it is quite sensitive to such outliers. In this case, though, it fails correctly, I believe, since we don't want such outliers in the posterior.
Do these outliers disappear with more training simulations? Which density estimator is used here?
If it is difficult to get the C2ST down, we could consider a different metric, like the KL divergence (DKL) or NLTP, and find a threshold based on cases where we know the estimator is accurate.
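Since the posterior here is tractable, a DKL-based check could use a simple Monte Carlo estimate from reference-posterior samples and the two log-densities. A minimal sketch (function names are illustrative), verified below against the analytic KL between two unit-variance Gaussians, $D_{KL}(N(0,1)\,\|\,N(\mu,1)) = \mu^2/2$:

```python
import numpy as np

def mc_kl(samples_p, log_p, log_q):
    """Monte Carlo estimate of D_KL(p || q): mean of log p(x) - log q(x)
    over samples x ~ p. Needs samples from p and both log-densities."""
    return np.mean(log_p(samples_p) - log_q(samples_p))

# Check on a case with a known answer: p = N(0,1), q = N(1,1), KL = 0.5.
rng = np.random.default_rng(0)
xs = rng.standard_normal(100_000)
log_p = lambda x: -0.5 * x**2 - 0.5 * np.log(2 * np.pi)
mu = 1.0
log_q = lambda x: -0.5 * (x - mu) ** 2 - 0.5 * np.log(2 * np.pi)
kl = mc_kl(xs, log_p, log_q)
```

In the test, `log_p` would be the analytic mixture posterior and `log_q` the density estimator's `log_prob`; the threshold could then be calibrated on runs known to be accurate.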