Add max_sampling_time support to rejection samplers and corresponding tests

Open satwiksps opened this issue 1 month ago • 3 comments

Summary

This PR adds an optional max_sampling_time argument to both rejection_sample() and accept_reject_sample() to prevent extremely long or infinite rejection-sampling loops in cases of heavy leakage or very low acceptance rates.

  • This enhancement directly addresses Issue #1699

When the acceptance rate is extremely low (e.g., due to leakage or a poorly matched proposal), rejection sampling can stall indefinitely. This becomes problematic when sbi is used inside larger loops (hyperparameter sweeps, clusters, parallel runs), where a single stalled run blocks the entire job.

To make rejection sampling safer and more predictable, we introduce a maximum time budget after which sampling aborts with a clear and actionable RuntimeError.

What this PR adds

New argument

  • Both samplers now accept max_sampling_time: Optional[float] = None
  • When None (default): behavior is unchanged.
  • When set: sampling aborts if the total time exceeds this threshold.

Time tracking

  • start_time = time.time() is recorded before the loop, and the elapsed time is checked against max_sampling_time at each iteration.

Safe early abort

  • If the timeout triggers, both functions raise a RuntimeError with guidance on how to handle slow sampling (e.g. switching to MCMC or VI).
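A minimal sketch of the pattern described above, written against a toy sampler rather than sbi's actual code. The function name rejection_sample_with_timeout, the propose/accept callables, and the error message wording are hypothetical and only illustrate the start-time, per-iteration check, and early abort:

```python
import random
import time
from typing import Callable, List, Optional


def rejection_sample_with_timeout(
    propose: Callable[[], float],
    accept: Callable[[float], bool],
    num_samples: int,
    max_sampling_time: Optional[float] = None,
) -> List[float]:
    """Toy rejection sampler with an optional wall-clock budget in seconds.

    Hypothetical illustration only; this is not sbi's implementation.
    """
    samples: List[float] = []
    start_time = time.time()  # recorded once, before the loop
    while len(samples) < num_samples:
        # Check the budget at every iteration; None disables the check.
        if max_sampling_time is not None:
            elapsed = time.time() - start_time
            if elapsed > max_sampling_time:
                raise RuntimeError(
                    "Rejection sampling exceeded max_sampling_time="
                    f"{max_sampling_time}s with {len(samples)}/{num_samples} "
                    "samples accepted. Consider MCMC or VI instead."
                )
        candidate = propose()
        if accept(candidate):
            samples.append(candidate)
    return samples


if __name__ == "__main__":
    rng = random.Random(0)
    # Roughly half of the proposals are accepted, so this returns quickly.
    print(rejection_sample_with_timeout(lambda: rng.uniform(-1, 1), lambda x: x > 0, 3))
    # Nothing is ever accepted, so the budget triggers the RuntimeError.
    try:
        rejection_sample_with_timeout(lambda: 0.0, lambda x: False, 1, max_sampling_time=0.1)
    except RuntimeError as err:
        print(err)
```

Because the check runs once per iteration, a single slow proposal or acceptance call can overshoot the budget; the timeout bounds the loop, not each individual step.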

Tests (tests/rejection_timeout_test.py)

  • Normal sampling still works.
  • Timeout is triggered when acceptance is ~0.
  • Both sampler functions correctly raise RuntimeError.
  • No regression in the existing sampling behavior.

Backwards compatibility

  • Existing code using rejection sampling continues to work as before.
  • Timeout is opt-in.
  • No breaking API changes.

Related Issue

  • Closes #1699

Checklist

  • [x] Add max_sampling_time to both functions
  • [x] Add timeout checks + early abort
  • [x] Add comprehensive tests
  • [x] Ensure no behavioral changes when timeout is None
  • [x] Passes pytest and pre-commit hooks

satwiksps avatar Nov 17 '25 16:11 satwiksps

Codecov Report

:x: Patch coverage is 68.96552% with 9 lines in your changes missing coverage. Please review. :warning: Please upload report for BASE (main@2c216c2). Learn more about missing BASE report. :warning: Report is 2 commits behind head on main.

Files with missing lines                             Patch %   Lines
sbi/inference/posteriors/vector_field_posterior.py   50.00%    6 Missing :warning:
sbi/inference/posteriors/direct_posterior.py         66.66%    2 Missing :warning:
sbi/samplers/rejection/rejection.py                  87.50%    1 Missing :warning:
Additional details and impacted files
@@           Coverage Diff           @@
##             main    #1705   +/-   ##
=======================================
  Coverage        ?   84.68%           
=======================================
  Files           ?      137           
  Lines           ?    11508           
  Branches        ?        0           
=======================================
  Hits            ?     9745           
  Misses          ?     1763           
  Partials        ?        0           
Flag Coverage Δ
unittests 84.68% <68.96%> (?)

Flags with carried forward coverage won't be shown.

Files with missing lines                             Coverage Δ
sbi/inference/posteriors/rejection_posterior.py      67.85% <100.00%> (ø)
sbi/samplers/rejection/rejection.py                  95.23% <87.50%> (ø)
sbi/inference/posteriors/direct_posterior.py         83.62% <66.66%> (ø)
sbi/inference/posteriors/vector_field_posterior.py   69.65% <50.00%> (ø)

codecov[bot] avatar Nov 17 '25 17:11 codecov[bot]

Thanks @satwiksps looks very good!

Just made suggestions for the error message.

Additionally, I am leaning towards introducing an actual timeout by default, e.g., 3600s or so. What do you think, @gmoss13 and @michaeldeistler?

Thanks @janfb, I agree that we should give a default time, but I would make it even smaller, something like 600s or even 300s. In my experience, we either have posteriors that can easily generate valid samples and the rejection sampling is very quick, or posteriors that have very low acceptance rates.

gmoss13 avatar Nov 18 '25 08:11 gmoss13

Hi @satwiksps! We had a quick discussion now about what behaviour we exactly want to have after this PR, and these are the resolutions:

  • The default max_sampling_time should be None as you initially had it
  • When the user sets a max_sampling_time, if it is exceeded, a RuntimeError should be raised.
  • As @michaeldeistler suggested in #1699, we could add a boolean reject_outside_prior flag to the posterior.sample() method that can be set to False to just return samples without rejection sampling.
  • When max_sampling_time is None, we should still have a warning, but we should also add an informative statement to the user about the options to turn off rejection sampling or set a max_sampling_time.
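The resolutions above could look roughly like the following toy sketch. It is not sbi code: sample_posterior, draw, in_prior_support, and the two warning thresholds are hypothetical names, and triggering the warning on a low acceptance rate after a minimum number of draws is an assumption about when the warning should fire.

```python
import time
import warnings
from typing import Callable, List, Optional


def sample_posterior(
    draw: Callable[[], float],
    in_prior_support: Callable[[float], bool],
    num_samples: int,
    max_sampling_time: Optional[float] = None,  # default None, as resolved
    reject_outside_prior: bool = True,  # proposed flag; False skips rejection
    warn_acceptance_rate: float = 0.01,  # assumed warning threshold
    min_draws_before_warning: int = 1000,  # assumed warm-up before warning
) -> List[float]:
    """Toy stand-in for posterior.sample(); hypothetical illustration only."""
    if not reject_outside_prior:
        # Return raw draws without checking them against the prior support.
        return [draw() for _ in range(num_samples)]

    samples: List[float] = []
    num_draws = 0
    warned = False
    start_time = time.time()
    while len(samples) < num_samples:
        if max_sampling_time is not None:
            if time.time() - start_time > max_sampling_time:
                raise RuntimeError(
                    f"Sampling exceeded max_sampling_time={max_sampling_time}s."
                )
        candidate = draw()
        num_draws += 1
        if in_prior_support(candidate):
            samples.append(candidate)
        rate = len(samples) / num_draws
        # Without a time budget, keep sampling but tell the user their options.
        if (
            max_sampling_time is None
            and not warned
            and num_draws >= min_draws_before_warning
            and rate < warn_acceptance_rate
        ):
            warnings.warn(
                f"Acceptance rate is only {rate:.4f}. Set "
                "reject_outside_prior=False to turn off rejection sampling, "
                "or set max_sampling_time to abort after a time budget.",
                stacklevel=2,
            )
            warned = True
    return samples
```

With max_sampling_time left at None the loop still never aborts on its own, so the warning is the only signal the user gets; that is the trade-off of keeping the timeout opt-in.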

How does this sound to you? Are you happy to add these things to your PR, or is there anything to clarify?

gmoss13 avatar Nov 20 '25 10:11 gmoss13