add: prototype implementation of tarp in sbi
What does this implement/fix? Explain your changes
TARP is a diagnostic method that can help identify over-/underdispersion and bias in trained neural posteriors. The corresponding paper is located here: https://arxiv.org/abs/2302.03026
The NumPy-based repo code is here: https://github.com/Ciela-Institute/tarp/
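For readers unfamiliar with the method, here is a minimal NumPy sketch of the coverage computation as I understand it from the paper. It is not the code from this PR or from the tarp repo. The function name `tarp_coverage`, its argument layout, and the choice to draw reference points uniformly within the range of the true parameters are all assumptions made for illustration.

```python
import numpy as np


def tarp_coverage(samples, theta_true, references=None, num_bins=20, rng=None):
    """Sketch of TARP expected coverage (hypothetical helper, not the sbi API).

    samples:    (num_samples, num_sims, dim) posterior samples per simulation
    theta_true: (num_sims, dim) parameters that generated each simulation
    """
    rng = np.random.default_rng(rng)
    if references is None:
        # Assumption: uniform reference points inside the range of theta_true.
        lo, hi = theta_true.min(axis=0), theta_true.max(axis=0)
        references = rng.uniform(lo, hi, size=theta_true.shape)

    # Distance of every posterior sample and of the true value to the reference.
    d_samples = np.linalg.norm(samples - references[None], axis=-1)
    d_true = np.linalg.norm(theta_true - references, axis=-1)

    # Fraction of posterior samples closer to the reference than the truth.
    f = (d_samples < d_true[None]).mean(axis=0)

    # Expected coverage: share of simulations with f below each level alpha.
    alpha = np.linspace(0.0, 1.0, num_bins + 1)
    ecp = (f[None, :] <= alpha[:, None]).mean(axis=1)
    return ecp, alpha


# Toy check with an exact posterior: theta ~ N(0, 1), x = theta + N(0, 1),
# so that p(theta | x) = N(x / 2, 1 / 2).
rng = np.random.default_rng(0)
num_sims, num_samples = 2000, 500
theta = rng.normal(size=(num_sims, 1))
x = theta + rng.normal(size=(num_sims, 1))
samples = x[None] / 2 + np.sqrt(0.5) * rng.normal(size=(num_samples, num_sims, 1))
ecp, alpha = tarp_coverage(samples, theta, rng=1)
```

For a calibrated posterior such as the toy one above, `ecp` should track `alpha` closely; systematic deviations from the diagonal are what the diagnostic is meant to expose.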
Does this close any currently open issues?
No, this was part of the Mar 2024 SBI hackathon in Tübingen
Any relevant code examples, logs, error output, etc?
Not yet; I am trying to reproduce the examples given in the paper. Later, I'd like to bring the tests and a tutorial in line with what is available for SBC.
Checklist
Put an x in the boxes that apply. You can also fill these out after creating
the PR. If you're unsure about any of them, don't hesitate to ask. We're here to
help! This is simply a reminder of what we are going to look for before merging
your code.
- [x] I have read and understood the contribution guidelines
- [x] I agree with re-licensing my contribution from AGPLv3 to Apache-2.0.
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have added tests that prove my fix is effective or that my feature works
- [x] I have reported how long the new tests run and potentially marked them with `pytest.mark.slow`.
- [x] New and existing unit tests pass locally with my changes
- [x] I performed linting and formatting as described in the contribution guidelines
- [x] I rebased on `main` (or there are no conflicts with `main`)
Note, I am stopping work on this PR for the time being. I ran into issues reproducing the tarp paper: https://github.com/Ciela-Institute/tarp/issues/8 Once they are resolved, I'll continue working on this.
Dear @janfb and @JuliaLinhart,
An alpha version of TARP (arxiv) as an SBI diagnostic is now ready from my point of view. I'd love for one of you to have a look. I added two files: sbi/diagnostics/tarp.py and tests/tarp_tests.py. The last unit test also documents how TARP would be used with SBI posterior predictions.
Feel free to have a look.
I have some questions, though:
- At this point, the TARP coverage estimates are returned as raw numbers, i.e. I don't perform any hypothesis testing on them. Should I add that (e.g. a KS test)?
- The `TARP` diagnostic class currently implements a `run` and a `check` function (to be aligned with the SBC code). In practice, the `run` function doesn't do anything TARP-related but rather draws samples from the posterior; `check` actually performs TARP, without any hypothesis test. I'm unsure whether `run` should instead compute the coverage stats and `check` do the hypothesis test, for example. What do you think?
- The TARP paper also offers a bootstrapped version of the diagnostic. Would we want to have that in SBI too?
- I think that, if TARP is included in SBI, there should be a tutorial about it. I'd rather not make this part of this PR though. Is that OK?
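To make the hypothesis-test question concrete, here is one possible shape such a check could take. This is a sketch only, not what this PR implements. It assumes the per-simulation fractions `f` (the share of posterior samples that lie closer to the reference point than the true parameter) are available, and it uses the fact that these fractions should be uniform on [0, 1] for a calibrated posterior. The function name `ks_uniform_statistic` is hypothetical, and 1.358 is the standard asymptotic KS critical constant for a 5% level.

```python
import numpy as np


def ks_uniform_statistic(f):
    """One-sample Kolmogorov-Smirnov statistic of f against Uniform(0, 1)."""
    f = np.sort(np.asarray(f, dtype=float))
    n = f.size
    ecdf_hi = np.arange(1, n + 1) / n  # empirical CDF just after each point
    ecdf_lo = np.arange(0, n) / n      # empirical CDF just before each point
    return max(np.max(ecdf_hi - f), np.max(f - ecdf_lo))


n = 1000
# Asymptotic 5% critical value; an approximation that needs a large n.
critical = 1.358 / np.sqrt(n)

# Deterministic stand-ins for the fractions f (illustration only):
f_calibrated = (np.arange(n) + 0.5) / n  # evenly spread over [0, 1]
f_skewed = f_calibrated**2               # mass piled up near 0

stat_calibrated = ks_uniform_statistic(f_calibrated)
stat_skewed = ks_uniform_statistic(f_skewed)

reject_calibrated = stat_calibrated > critical
reject_skewed = stat_skewed > critical
```

Whether a test like this belongs in `check` or in a separate function is exactly the open design question above.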
@psteinb just checking in here. Do you plan to continue working on this soon to get it into the upcoming release? We could also schedule a video call to discuss my points, which could be more efficient than sending another round of comments back and forth.
Sure thing, will contact you directly on this.
As just discussed:
- short-term goal for this PR is to make TARP available in sbi in the same way that SBC is
- no progress bars
- try to reuse plotting functions as much as possible that already exist for sbc
@janfb I refactored the tarp implementation as discussed offline. Feel free to have a look again. Should I add a small tutorial notebook?
Codecov Report
Attention: Patch coverage is 56.84211% with 41 lines in your changes missing coverage. Please review.
Project coverage is 75.55%. Comparing base (6f61662) to head (dfcf3a2). Report is 1 commits behind head on main.
:exclamation: There is a different number of reports uploaded between BASE (6f61662) and HEAD (dfcf3a2).
HEAD has 1 upload more than BASE
| Flag | BASE (6f61662) | HEAD (dfcf3a2) |
|------|------|------|
| unittests | 1 | 2 |
Additional details and impacted files
@@ Coverage Diff @@
## main #1106 +/- ##
==========================================
- Coverage 82.13% 75.55% -6.59%
==========================================
Files 93 94 +1
Lines 7458 7571 +113
==========================================
- Hits 6126 5720 -406
- Misses 1332 1851 +519
| Flag | Coverage Δ | |
|---|---|---|
| unittests | 75.55% <56.84%> (-6.59%) | :arrow_down: |
| Files | Coverage Δ | |
|---|---|---|
| sbi/utils/metrics.py | 40.33% <100.00%> (-53.46%) | :arrow_down: |
| sbi/diagnostics/tarp.py | 54.94% <54.94%> (ø) | |
Merging this into main for now. The method works with an API similar to that of the SBC method.
What is missing and postponed to a future PR:
- unify API with SBC
- refactor to have most of the logic in a `run_tarp` method. Maybe there is even a way to combine it with SBC
- make use of batched sampling
- add explanation and code to Diagnostics tutorial
The next PR could come with a bigger refactoring for the SBC methods as well.
Thanks so much for the final commits to push this across the finish line.
Hi @psteinb! Sorry, I really haven't had the time to look at this PR so far! But I think it's great you took the time to look into and implement this diagnostic! Thank you! I will have a look at it as soon as I can!
No problem @JuliaLinhart and thanks for your kind words. I am happy that TARP is part of SBI now. It is a global method and quite fast. But the details are a bit hard to grasp from my point of view.