Set maximum time for rejection sampling
🚀 Feature Request
Is your request related to a problem?
Almost all posteriors in `sbi` sample with `accept_reject_sample`. If the acceptance rate is too low, this raises a warning that sampling will be slow, but it does not stop trying to sample with rejection.
This is fine when running sbi locally, e.g. in a notebook, but it means you have to work around it if sbi is part of a bigger loop (e.g. when optimizing over hyperparameters @swag2198, or running many sbi runs in parallel on a cluster). In that case your job is stuck and does not output anything, as it just runs until you stop it or the job times out.
Describe the solution you'd like
We should add an argument, e.g. `max_sampling_time`, after which we give up on rejection sampling. We can then raise an error whose message includes instructions on how to catch it and define custom behaviour (e.g., sample the posterior without restriction, or return prior samples, i.e., what to do in a "failed" run).
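For example (a hypothetical sketch of the proposed API; the argument name and the exact exception type are not final), user code could then look like:

```python
try:
    # Hypothetical argument: give up on rejection sampling after 10 minutes.
    samples = posterior.sample((1000,), x=x_o, max_sampling_time=600)
except RuntimeError:
    # Custom behaviour for a "failed" run, e.g. fall back to prior samples.
    samples = prior.sample((1000,))
```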
Describe alternatives you've considered
We can leave the code as it is and add a tutorial/faq entry.
Hi! @gmoss13 @janfb I would like to work on this issue.
My plan for the fix is as follows:
- Add a new optional argument (e.g., `max_sampling_time`) to the rejection sampling routine (`accept_reject_sample`).
- Track the sampling start time and periodically check whether the elapsed time exceeds the user-specified limit.
- If the maximum time is exceeded, raise a clear `RuntimeError` with guidance on how users can catch the error and implement a fallback (e.g., unrestricted posterior sampling or prior sampling); see the sketch after this list.
- Add corresponding tests to ensure the sampler stops as expected when `max_sampling_time` is reached.
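A minimal, self-contained sketch of what the timeout check could look like (the signature is simplified and the names are placeholders, not the actual `sbi` internals):

```python
import time

import torch


def accept_reject_sample(proposal, accept_fn, num_samples, max_sampling_time=None):
    """Toy rejection sampler with an optional wall-clock time limit in seconds."""
    start = time.monotonic()
    accepted, num_accepted = [], 0
    while num_accepted < num_samples:
        candidates = proposal.sample((1000,))
        mask = accept_fn(candidates)  # boolean mask of accepted candidates
        accepted.append(candidates[mask])
        num_accepted += int(mask.sum())
        if max_sampling_time is not None and time.monotonic() - start > max_sampling_time:
            raise RuntimeError(
                f"Rejection sampling exceeded max_sampling_time={max_sampling_time}s "
                f"with only {num_accepted} accepted samples. Catch this error to "
                "define custom fallback behaviour, e.g. sampling without the "
                "prior-support restriction or returning prior samples."
            )
    return torch.cat(accepted)[:num_samples]
```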
Let me know if you have suggestions or want the behavior to be slightly different. I'll begin working on it.
Hi @satwiksps!
Glad you want to work on this issue, and happy for you to make a PR on this. Your suggested fix seems reasonable to me.
Yes, sounds good!
In general, @satwiksps, kudos for offering your help with many different issues here, that's great!
Your plans and PRs are very structured and well-reviewable - much appreciated! 🙏
Sorry, only diving into this now: I think even more useful than a `max_sampling_time` would be a flag `reject_outside_prior_support: bool`. Using this flag, the user could then manually check which samples to accept via

```python
samples = posterior.sample((100,), x=x_o, reject_outside_prior_support=False)
accept = prior.support.check(samples)
samples_accepted = samples[accept]
```

EDIT: thinking more about it, I am not so sure anymore if I prefer the above over `max_sampling_time`.
I see the point of more flexibility here, but technically it would be a bit fuzzy to sample from a posterior while allowing samples outside the prior support, no? If the user really wants this flexibility, they could always sample from the density estimator object directly, right? (We could add a how-to guide for that.)
Yes, they could just sample from the `density_estimator`. However, it has a different API, so it is a bit less beginner-friendly.
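For reference, sampling the density estimator directly could look roughly like this (a sketch only; the exact `DensityEstimator.sample` signature and batch-dimension handling depend on the `sbi` version):

```python
density_estimator = inference.train()

# Sample the raw density estimator, bypassing the posterior's rejection step.
# The condition is expected to carry a batch dimension.
samples = density_estimator.sample((100,), condition=x_o.unsqueeze(0))
samples = samples.squeeze(1)  # drop the batch dimension of size 1

# Then keep only the samples inside the prior support, as suggested above.
accept = prior.support.check(samples)
samples_accepted = samples[accept]
```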
I think the issue of "no accepted samples" happens to so many users that we should consider having the `reject_outside_prior_support` flag, just to make it very easy for new users to see what is happening. We could even add this to the "sampling is slow" warning.
I see your point @michaeldeistler.
I still think `max_sampling_time` is the better way to go: given only a warning, a script would need to be rerun after changing the flag to `reject_outside_prior_support = False`, whereas if we get an error after some time, the user can catch it and define custom behaviour. In my opinion, that makes it a bit easier to write reproducible code when you are training many inference objects. On the other hand, for exploring/training in notebooks, a `reject_outside_prior_support` flag is probably more convenient. So could we even do both?