Add max_sampling_time support to rejection samplers and corresponding tests
Summary
This PR adds an optional max_sampling_time argument to both rejection_sample() and accept_reject_sample() to prevent extremely long or infinite rejection-sampling loops in cases of heavy leakage or very low acceptance rates.
- This enhancement directly addresses Issue #1699
When the acceptance rate is extremely low (e.g., due to leakage or a poorly matched proposal), rejection sampling can stall indefinitely. This becomes problematic when sbi is used inside larger loops (hyperparameter sweeps, clusters, parallel runs), where a single stalled run blocks the entire job.
To make rejection sampling safer and more predictable, we introduce a maximum time budget after which sampling aborts with a clear and actionable RuntimeError.
What this PR adds
New argument
- Both samplers now accept `max_sampling_time: Optional[float] = None`.
- When `None` (default): behavior is unchanged.
- When set: sampling aborts if the total time exceeds this threshold.
Time tracking
- A `start_time = time.time()` is recorded before the loop, and checked at each iteration.
Safe early abort
- If the timeout triggers, both functions raise a `RuntimeError` with guidance on how to handle slow sampling (e.g. switching to MCMC or VI).
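The time tracking and early abort described above can be illustrated with a minimal, self-contained sketch. This is not the sbi implementation: `toy_rejection_sample`, `propose`, and `accept` are hypothetical names used only for illustration, and the error text is a placeholder for the wording in the PR.

```python
import random
import time
from typing import Callable, List, Optional


def toy_rejection_sample(
    propose: Callable[[], float],
    accept: Callable[[float], bool],
    num_samples: int,
    max_sampling_time: Optional[float] = None,
) -> List[float]:
    """Toy accept/reject loop with an optional wall-clock budget (not sbi code)."""
    samples: List[float] = []
    start_time = time.time()  # recorded once, before the loop
    while len(samples) < num_samples:
        # Checked at each iteration; skipped entirely when the budget is None,
        # so the default behavior is an unbounded loop as before.
        if max_sampling_time is not None:
            elapsed = time.time() - start_time
            if elapsed > max_sampling_time:
                raise RuntimeError(
                    f"Rejection sampling exceeded max_sampling_time="
                    f"{max_sampling_time}s after {elapsed:.2f}s with "
                    f"{len(samples)}/{num_samples} samples accepted. "
                    "Consider switching to MCMC or VI."
                )
        candidate = propose()
        if accept(candidate):
            samples.append(candidate)
    return samples


if __name__ == "__main__":
    # Healthy case: about half of the proposals are accepted.
    print(len(toy_rejection_sample(random.random, lambda x: x < 0.5, 10)))
    # Pathological case: nothing is ever accepted, so the budget triggers.
    try:
        toy_rejection_sample(random.random, lambda x: False, 1, max_sampling_time=0.1)
    except RuntimeError as err:
        print(err)
```

Because the check runs once per iteration, the abort can overshoot the budget by up to the duration of one proposal-and-accept step.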
Tests (`tests/rejection_timeout_test.py`)
- Normal sampling still works.
- Timeout is triggered when acceptance is ~0.
- Both sampler functions correctly raise `RuntimeError`.
- No regression in the existing sampling behavior.
Backwards compatibility
- Existing code using rejection sampling continues to work as before.
- Timeout is opt-in.
- No breaking API changes.
Related Issue
- Closes #1699
Checklist
- [x] Add `max_sampling_time` to both functions
- [x] Add timeout checks + early abort
- [x] Add comprehensive tests
- [x] Ensure no behavioral changes when timeout is `None`
- [x] Passes `pytest` and `pre-commit` hooks
Codecov Report
:x: Patch coverage is 68.96552% with 9 lines in your changes missing coverage. Please review.
:warning: Please upload report for BASE (main@2c216c2).
:warning: Report is 2 commits behind head on main.
Additional details and impacted files
@@ Coverage Diff @@
## main #1705 +/- ##
=======================================
Coverage ? 84.68%
=======================================
Files ? 137
Lines ? 11508
Branches ? 0
=======================================
Hits ? 9745
Misses ? 1763
Partials ? 0
| Flag | Coverage Δ | |
|---|---|---|
| unittests | 84.68% <68.96%> (?) |
Flags with carried forward coverage won't be shown.
| Files with missing lines | Coverage Δ | |
|---|---|---|
| sbi/inference/posteriors/rejection_posterior.py | 67.85% <100.00%> (ø) | |
| sbi/samplers/rejection/rejection.py | 95.23% <87.50%> (ø) | |
| sbi/inference/posteriors/direct_posterior.py | 83.62% <66.66%> (ø) | |
| sbi/inference/posteriors/vector_field_posterior.py | 69.65% <50.00%> (ø) | |
Thanks @satwiksps, looks very good!
Just made suggestions for the error message.
Additionally, I am leaning towards introducing an actual timeout by default, e.g., 3600s or so. What do you think, @gmoss13 and @michaeldeistler?
Thanks @janfb, I agree that we should give a default time, but I would make it even smaller, something like 600s or even 300s. In my experience, we either have posteriors that can easily generate valid samples and the rejection sampling is very quick, or posteriors that have very low acceptance rates.
Hi @satwiksps! We had a quick discussion now about what behaviour we exactly want to have after this PR, and these are the resolutions:
- The default `max_sampling_time` should be `None`, as you initially had it.
- When the user sets a `max_sampling_time` and it is exceeded, a `RuntimeError` should be raised.
- As @michaeldeistler suggested in #1699, we could add a boolean `reject_outside_prior` flag to the `posterior.sample()` method that lets users just return samples without rejection sampling.
- When `max_sampling_time` is `None`, we should still have a warning, but we should also add an informative statement to the user about the options to turn off rejection sampling or set a `max_sampling_time`.
How does this sound to you? Are you happy to add these things to your PR, or is there anything to clarify?
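To make the resolutions above concrete, here is a rough sketch of how the three behaviours could fit together. Everything in it is an assumption for illustration: `toy_posterior_sample`, the `acceptance_prob` stand-in for "inside the prior support", the thresholds `CHECK_AFTER` and `LOW_ACCEPTANCE`, the choice that `reject_outside_prior=False` skips rejection, and the message wording are hypothetical and not taken from sbi.

```python
import random
import time
import warnings
from typing import List, Optional

CHECK_AFTER = 1000      # hypothetical: proposals drawn before judging acceptance
LOW_ACCEPTANCE = 0.01   # hypothetical: acceptance rate considered "too low"


def toy_posterior_sample(
    num_samples: int,
    acceptance_prob: float,
    max_sampling_time: Optional[float] = None,  # default stays None
    reject_outside_prior: bool = True,
    rng: Optional[random.Random] = None,
) -> List[float]:
    """Toy stand-in for posterior.sample(); a draw counts as inside the prior
    support when it is below `acceptance_prob`."""
    rng = rng or random.Random()
    if not reject_outside_prior:
        # Rejection is turned off: return raw draws, some possibly outside the prior.
        return [rng.random() for _ in range(num_samples)]

    samples: List[float] = []
    num_proposed = 0
    warned = False
    start_time = time.time()
    while len(samples) < num_samples:
        if max_sampling_time is not None and time.time() - start_time > max_sampling_time:
            raise RuntimeError(
                f"Sampling exceeded max_sampling_time={max_sampling_time}s."
            )
        draw = rng.random()
        num_proposed += 1
        if draw < acceptance_prob:
            samples.append(draw)
        # Warn once, and name both ways out in the message.
        if (
            not warned
            and num_proposed >= CHECK_AFTER
            and len(samples) / num_proposed < LOW_ACCEPTANCE
        ):
            warnings.warn(
                "Very low acceptance rate. Consider reject_outside_prior=False "
                "to skip rejection sampling, or set max_sampling_time to abort "
                "after a fixed time budget.",
                stacklevel=2,
            )
            warned = True
    return samples
```

With `max_sampling_time=None` and a low acceptance rate, the loop still runs until enough samples are accepted, but the warning points the user to both options.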