pymc
FlatSwitch Op for logprob derivation of arbitrary censoring
What is this PR about?
This PR defines a FlatSwitch Op that aims to extract the intervals, and the encoding associated with each interval, that are required to infer the logprob of arbitrarily censored distributions. It does so in the following steps:
- Recursively extract the intervals defined by the conditions in `pt.switch()`.
- Adjust/limit these intervals to eliminate any overlap with the intervals already claimed by the outer switch.
- Identify which intervals the true and false branches correspond to, since each condition splits the space into two parts.
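The three steps can be sketched in plain Python. This is a toy model under stated assumptions, not the PR's implementation: the hypothetical `Switch` class stands in for `pt.switch(x < threshold, if_true, if_false)` with a single "less than" condition, and a leaf is either a constant encoding or the string `"x"` for the base RV passed through unchanged.

```python
import math
from dataclasses import dataclass
from typing import Union


@dataclass
class Switch:
    """Hypothetical stand-in for pt.switch(x < threshold, if_true, if_false)."""
    threshold: float
    if_true: "Node"
    if_false: "Node"


# A leaf is a constant encoding (a number) or "x", meaning the base RV itself.
Node = Union[Switch, float, str]


def flatten_switch(node, lower=-math.inf, upper=math.inf):
    """Return [(lower, upper, encoding), ...] covering [lower, upper)."""
    if not isinstance(node, Switch):
        return [(lower, upper, node)]
    # Step 2: clip the split point to the interval inherited from the outer switch,
    # so an inner condition can never claim values the outer switch already took.
    split = min(max(node.threshold, lower), upper)
    intervals = []
    # Step 3: the true branch owns [lower, split), the false branch owns [split, upper).
    if lower < split:
        intervals += flatten_switch(node.if_true, lower, split)
    if split < upper:
        intervals += flatten_switch(node.if_false, split, upper)
    return intervals


# The nested switch from the example in this PR.
example = Switch(-1.0, -1.0, Switch(1.0, 1.0, "x"))
print(flatten_switch(example))
```

For that nested switch this returns `[(-inf, -1.0, -1.0), (-1.0, 1.0, 1.0), (1.0, inf, 'x')]`; the last two entries correspond to the intervals printed in the PR's example output. An inner threshold lying outside its inherited interval is clipped away rather than producing an overlapping interval.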
It then checks that neither the switch condition nor any measurable branch is broadcast, and that all measurable components share the same source of measurability. The logic for these checks is based on #6834.
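The broadcasting condition can be illustrated with a plain-Python sketch on concrete shapes. The real checks operate on PyTensor graphs; `broadcast_shape` and `no_broadcasting` are hypothetical helpers. The presumed rationale: if the condition or a measurable branch had to be broadcast to the output shape, one element of the base RV would feed several output elements, so the per-interval logprob would no longer describe a single draw. Constant encodings, by contrast, may broadcast freely.

```python
def broadcast_shape(*shapes):
    """NumPy-style broadcasting of shape tuples (raises on incompatible dims)."""
    ndim = max(len(s) for s in shapes)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for dims in zip(*padded):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible dimensions {dims}")
        out.append(sizes.pop() if sizes else 1)
    return tuple(out)


def no_broadcasting(cond_shape, measurable_shapes, constant_shapes=()):
    """True if neither the condition nor any measurable branch gets broadcast."""
    out_shape = broadcast_shape(cond_shape, *measurable_shapes, *constant_shapes)
    # The condition must already have the output shape.
    if tuple(cond_shape) != out_shape:
        return False
    # Every measurable branch must already have the output shape as well.
    return all(tuple(s) == out_shape for s in measurable_shapes)


# A scalar constant encoding broadcasting against a length-3 base RV is fine.
print(no_broadcasting((3,), [(3,)], [()]))
# A scalar condition broadcast against a length-3 measurable branch is rejected.
print(no_broadcasting((), [(3,)], [()]))
```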
Once the intervals and their respective encodings are known, they can be used to compute the log-probability. For example, running
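```python
import pytensor.tensor as pt

# Assumed base RV: the output below mentions normal_rv, so a normal is used here.
base_rv = pt.random.normal(0, 1)

rv2 = pt.switch(
    base_rv < -1,
    -1,
    pt.switch(
        base_rv < 1,  # within base_rv >= -1: [-1, 1) -> 1, [1, inf) -> base_rv
        1,
        base_rv,
    ),
)
```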
we get something like:
```
lower: -1.0
upper: 1.0
encoding: 1

lower: 1.0
upper: inf
encoding: normal_rv{0, (0, 0), floatX, False}.out
```
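To show how intervals and encodings could turn into a log-probability, here is a minimal plain-Python sketch, assuming `base_rv` is a standard normal and using the three intervals implied by the example switch: (-inf, -1) -> -1, [-1, 1) -> 1, [1, inf) -> `base_rv`. It is not PyMC's logprob code: `switch_logprob` and the `"x"` marker for "base RV passed through" are hypothetical. A constant encoding is treated as a point mass carrying the probability of its interval; a pass-through interval contributes the base density.

```python
import math


def normal_cdf(x):
    """Standard normal CDF via the complementary error function."""
    return 0.5 * math.erfc(-x / math.sqrt(2.0))


def normal_logpdf(x):
    """Standard normal log-density."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def switch_logprob(value, intervals):
    """intervals: [(lower, upper, encoding)]; encoding is a constant or "x"."""
    # Point masses: sum the probability of every interval encoded as this constant.
    mass = sum(
        normal_cdf(hi) - normal_cdf(lo)
        for lo, hi, enc in intervals
        if enc != "x" and enc == value
    )
    if mass > 0:
        return math.log(mass)
    # Otherwise the value must fall in an interval where base_rv passes through.
    for lo, hi, enc in intervals:
        if enc == "x" and lo <= value < hi:
            return normal_logpdf(value)
    return -math.inf  # the value cannot be produced by this switch


intervals = [(-math.inf, -1.0, -1.0), (-1.0, 1.0, 1.0), (1.0, math.inf, "x")]
print(switch_logprob(1.0, intervals))  # log P(-1 <= base_rv < 1)
print(switch_logprob(0.0, intervals))  # unreachable value
```

Checking point masses before densities is a design choice of this sketch: at `value == 1.0` the atom from the middle interval takes precedence over the density of the pass-through interval that starts at the same point.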
TO-DO:
- [x] The checks should work not only on `pt.switch(x > 0, x, a)` but also on `pt.switch(pt.exp(x) > 0, pt.exp(x), b)`, where `a` and `b` are some encodings.
- [x] In the FlatSwitch Op, add `base_rv`, the list of intervals, and the corresponding encodings as inputs to the node so that they can be unpacked in the logprob calculation.
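Supporting `pt.switch(pt.exp(x) > 0, pt.exp(x), b)` amounts to tracing measurability through elementwise transforms. A toy sketch of the "same source of measurability" check, assuming a hypothetical tuple encoding of expressions (`("rv", name)`, `("const", value)`, or `(op_name, *args)`) rather than real PyTensor graphs: walk each expression down to its random-variable leaves and require that the condition and every measurable branch reach exactly one, shared RV.

```python
def measurable_sources(expr):
    """Collect the names of all random variables an expression depends on."""
    kind = expr[0]
    if kind == "rv":
        return {expr[1]}
    if kind == "const":
        return set()
    sources = set()
    for arg in expr[1:]:  # elementwise ops such as "exp" or ">" just recurse
        sources |= measurable_sources(arg)
    return sources


def same_measurable_source(cond, *branches):
    """True if the condition and all measurable branches derive from one shared RV."""
    nonempty = [s for s in map(measurable_sources, (cond, *branches)) if s]
    if not nonempty:
        return False  # nothing measurable at all
    first = nonempty[0]
    return len(first) == 1 and all(s == first for s in nonempty)


x, y = ("rv", "x"), ("rv", "y")
b = ("const", 5.0)
# pt.switch(pt.exp(x) > 0, pt.exp(x), b): both sides trace back to x.
print(same_measurable_source((">", ("exp", x), ("const", 0.0)), ("exp", x), b))
```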
Checklist
- [x] Explain important implementation details 👆
- [x] Make sure that the pre-commit linting/style checks pass.
- [x] Link relevant issues (preferably in nice commit messages)
- [ ] Are the changes covered by tests and docstrings?
- [x] Fill out the short summary sections 👇
@ricardoV94 @larryshamalama
Codecov Report
Attention: Patch coverage is 17.91045%, with 110 lines in your changes missing coverage. Please review.
Project coverage is 87.21%. Comparing base (244fb97) to head (b5f26a4).
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #6949      +/-   ##
==========================================
- Coverage   92.26%   87.21%   -5.06%
==========================================
  Files         100      100
  Lines       16880    17009     +129
==========================================
- Hits        15574    14834     -740
- Misses       1306     2175     +869
```
| Files | Coverage Δ | |
|---|---|---|
| pymc/logprob/censoring.py | 34.23% <17.91%> (-61.47%) | :arrow_down: |
@shreyas3156 could you resolve the merge conflicts? I'll finally review this one :)
One of the pre-existing tests is failing. I'm not sure whether it's due to these changes, but I would guess so?
https://github.com/pymc-devs/pymc/actions/runs/8169791795/job/22334588152?pr=6949#step:7:478