pymc

FlatSwitch Op for logprob derivation of arbitrary censoring

Open shreyas3156 opened this issue 2 years ago • 3 comments

What is this PR about?

This PR defines a FlatSwitch Op that aims to extract the intervals, and their respective encodings, needed to infer the logprob of arbitrarily censored distributions. It does this in the following steps:

  1. Recursively extract the intervals defined by the conditions of (nested) pt.switch() calls.
  2. Adjust/limit these intervals to eliminate the overlap with the outer switch.
  3. Identify which interval each of the true and false branches corresponds to, since each condition splits the space into two parts.
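The three steps above can be sketched in plain Python. This is a minimal illustration, not the PR's implementation: it assumes a hypothetical tuple representation of nested switches, handles only conditions of the form base_rv < c, and treats every interval as half-open [lower, upper).

```python
import math
from typing import List, Tuple, Union

# Hypothetical representation: ("switch", c, true_branch, false_branch) stands for
# pt.switch(base_rv < c, true_branch, false_branch). A leaf is either a numeric
# constant (the encoding) or the string "base_rv" (the value passes through).
Node = Union[float, str, tuple]


def extract_intervals(
    node: Node, lower: float = -math.inf, upper: float = math.inf
) -> List[Tuple[float, float, Union[float, str]]]:
    """Return (lower, upper, encoding) triples for the half-open intervals [lower, upper)."""
    if not isinstance(node, tuple):
        # Leaf: the whole interval left over by the outer switches maps to this encoding
        return [(lower, upper, node)]
    _, c, true_branch, false_branch = node
    out = []
    # True branch (base_rv < c), clipped to the interval inherited from outer switches
    if lower < min(upper, c):
        out += extract_intervals(true_branch, lower, min(upper, c))
    # False branch (base_rv >= c), clipped the same way; empty pieces are dropped
    if max(lower, c) < upper:
        out += extract_intervals(false_branch, max(lower, c), upper)
    return out


# The nested switch from the example in this PR, in the hypothetical tuple form
rv2 = ("switch", -1.0, -1.0, ("switch", 1.0, 1.0, "base_rv"))
for lo, hi, enc in extract_intervals(rv2):
    print(lo, hi, enc)
# -inf -1.0 -1.0
# -1.0 1.0 1.0
# 1.0 inf base_rv
```

Passing the inherited interval down the recursion is what handles step 2: an inner switch can never claim values that an outer condition has already sent elsewhere.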

It then checks that neither the switch condition nor any measurable branch is broadcast, and that all measurable components share the same source of measurability. The logic for these checks is based on #6834.
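One way to read these two checks, sketched with NumPy shapes instead of PyTensor graphs (the function name and its arguments are hypothetical): the output shape is the broadcast of the condition and both branches. The condition and every measurable branch must already have that shape, while a constant branch such as -1 may still be broadcast. np.broadcast_shapes requires NumPy 1.20 or later.

```python
import numpy as np


def switch_is_valid(cond_shape, branch_shapes, measurable, sources):
    """Hypothetical sketch of the two checks on a single switch.

    cond_shape: shape of the switch condition
    branch_shapes: shapes of the true and false branches
    measurable: one flag per branch, True if it is derived from a base RV
    sources: identifiers of the base RV behind each measurable component
    """
    out_shape = np.broadcast_shapes(cond_shape, *branch_shapes)
    # Check 1: the condition must not be broadcast by the switch
    if tuple(cond_shape) != out_shape:
        return False
    # Check 1 (cont.): no measurable branch may be broadcast either
    for shape, is_measurable in zip(branch_shapes, measurable):
        if is_measurable and tuple(shape) != out_shape:
            return False
    # Check 2: all measurable components share one source of measurability
    return len(set(sources)) == 1


# A scalar constant branch is fine; a broadcast measurable branch is not
print(switch_is_valid((3,), [(), (3,)], [False, True], ["x", "x"]))
print(switch_is_valid((3,), [(1,), (3,)], [True, True], ["x", "x"]))
```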

Once the intervals and their encodings are known, they can be used to compute the log-probability. For example, on running something like

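import pytensor.tensor as pt

base_rv = pt.random.normal()  # assumed: a standard normal base variable
rv2 = pt.switch(
    base_rv < -1,
    -1,
    pt.switch(
        base_rv < 1,  # splits the remaining space into [-1, 1) and [1, inf)
        1,
        base_rv,
    ),
)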

we get something like:

lower: -1.0
upper: 1.0
encoding: 1 

lower: 1.0
upper: inf
encoding: normal_rv{0, (0, 0), floatX, False}.out 
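To illustrate how the intervals feed into the log-probability, here is a sketch for the example above. It assumes base_rv is a standard normal and uses the three intervals implied by the nested switch: (-inf, -1) → -1, [-1, 1) → 1 and [1, inf) → base_rv (the printed output shows the last two). Constant encodings become point masses with probability CDF(upper) - CDF(lower), and the base_rv interval keeps the normal density. The function names are hypothetical and this is not PyMC's censoring logprob.

```python
import math


def norm_cdf(x: float) -> float:
    # Standard normal CDF; math.erf(+/-inf) is +/-1, so infinite bounds work
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def norm_logpdf(x: float) -> float:
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def censored_logp(value, intervals):
    """Hypothetical: logp of `value` for a switch-censored standard normal.

    intervals: (lower, upper, encoding) triples, where encoding is either a
    constant (a point mass) or "base_rv" (the value passes through unchanged).
    """
    # Point masses: add up the probability of every interval mapped onto `value`
    mass = sum(
        norm_cdf(upper) - norm_cdf(lower)
        for lower, upper, enc in intervals
        if enc != "base_rv" and enc == value
    )
    if mass > 0:
        return math.log(mass)
    # Continuous part: the base density wherever the value is left uncensored
    for lower, upper, enc in intervals:
        if enc == "base_rv" and lower <= value < upper:
            return norm_logpdf(value)
    return -math.inf


intervals = [(-math.inf, -1.0, -1.0), (-1.0, 1.0, 1.0), (1.0, math.inf, "base_rv")]
print(censored_logp(1.0, intervals))  # log P(-1 <= X < 1)
print(censored_logp(2.0, intervals))  # standard normal log-density at 2
```

Point masses are checked before the density, so a value such as 1.0 that sits on an interval boundary gets the mass; summing over intervals also covers the case where several intervals share one encoding.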

TO-DO:

  • [x] The checks should work not only on pt.switch(x>0, x, a) but also on pt.switch(pt.exp(x)>0, pt.exp(x), b), where a and b are some encodings.
  • [x] In the FlatSwitch Op, add base_rv, the list of intervals, and the corresponding encodings as inputs to the node so that they can be unpacked during the logprob calculation.
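For the second TO-DO item, one possible layout is a flat input list that the logprob function splits back into triples, since a node's inputs are a flat list. In this sketch plain Python values stand in for PyTensor variables and the helper names are hypothetical; in the actual Op each entry would have to be a symbolic variable (constants wrapped with pt.as_tensor_variable, for instance), and the uncensored encoding would be base_rv itself rather than a string tag.

```python
import math


def pack_inputs(base_rv, intervals):
    """Hypothetical flat layout: [base_rv, lower_1, upper_1, enc_1, lower_2, ...]."""
    flat = [base_rv]
    for lower, upper, encoding in intervals:
        flat.extend([lower, upper, encoding])
    return flat


def unpack_inputs(flat):
    """Inverse of pack_inputs, as a logprob function might apply it to node.inputs."""
    base_rv, rest = flat[0], list(flat[1:])
    if len(rest) % 3 != 0:
        raise ValueError("expected a whole number of (lower, upper, encoding) triples")
    intervals = [tuple(rest[i:i + 3]) for i in range(0, len(rest), 3)]
    return base_rv, intervals


intervals = [(-math.inf, -1.0, -1.0), (-1.0, 1.0, 1.0), (1.0, math.inf, "base_rv")]
flat = pack_inputs("base_rv", intervals)
print(unpack_inputs(flat))
```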


@ricardoV94 @larryshamalama

shreyas3156, Oct 12 '23 06:10

Codecov Report

Attention: Patch coverage is 17.91045% with 110 lines in your changes missing coverage. Please review.

Project coverage is 87.21%. Comparing base (244fb97) to head (b5f26a4).



@@            Coverage Diff             @@
##             main    #6949      +/-   ##
==========================================
- Coverage   92.26%   87.21%   -5.06%     
==========================================
  Files         100      100              
  Lines       16880    17009     +129     
==========================================
- Hits        15574    14834     -740     
- Misses       1306     2175     +869     
Files                        Coverage Δ
pymc/logprob/censoring.py    34.23% <17.91%> (-61.47%) ↓

... and 21 files with indirect coverage changes

codecov[bot], Oct 12 '23 07:10

@shreyas3156 could you resolve the merge conflicts? I'll finally review this one :)

ricardoV94, Mar 04 '24 10:03

One of the pre-existing tests is failing. I'm not sure whether it's due to these changes, but I'd guess so?

https://github.com/pymc-devs/pymc/actions/runs/8169791795/job/22334588152?pr=6949#step:7:478

ricardoV94, Mar 13 '24 09:03