pymc
Allow Truncation of CustomDist
This is now possible:
import pymc as pm
import numpy as np
def maxwell_dist(scale, size):
    return pm.math.sqrt(pm.ChiSquared.dist(nu=3, size=size)) * scale
scale = 5.0
x = pm.CustomDist.dist(scale, dist=maxwell_dist)
trunc_x = pm.Truncated.dist(x, lower=0, upper=2, size=(1000,))
assert np.all(pm.draw(trunc_x) < 2)
trunc_x = pm.Truncated.dist(x, lower=0, upper=5, size=())
assert pm.logp(trunc_x, 3.0).eval() > pm.logp(x, 3.0).eval()
This required cleaning up the interface of SymbolicRandomVariables (mainly circumventing https://github.com/pymc-devs/pytensor/issues/473) so that we can safely "box" the base RVs in the inner OpFromGraph (i.e., recreate them with new shared inputs).
This challenge is very specific to Truncated, which needs to "resample" the base RV for its rejection-based algorithm.
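A plain NumPy sketch of rejection-based truncation shows why the base RV must be re-run. It illustrates the idea only and is not PyMC's Scan-based implementation: entries that land outside the bounds are redrawn from the base distribution with fresh randomness until all are accepted. The base RV is modeled here as a callable taking a Generator and a size, standing in for a "boxed" RV; truncated_rejection_sample, maxwell_draw, and max_n_steps are hypothetical names.

```python
import numpy as np


def truncated_rejection_sample(draw_fn, lower, upper, size, rng, max_n_steps=10_000):
    """Rejection-sample draw_fn restricted to [lower, upper].

    draw_fn(rng, n) is a hypothetical stand-in for a "boxed" base RV:
    it can be re-invoked with fresh randomness as often as needed.
    """
    out = draw_fn(rng, size)
    for _ in range(max_n_steps):
        # Entries outside the bounds must be redrawn
        bad = (out < lower) | (out > upper)
        if not bad.any():
            return out
        # Re-run the base distribution only for the rejected entries
        out[bad] = draw_fn(rng, int(bad.sum()))
    raise RuntimeError("truncation did not converge within max_n_steps")


def maxwell_draw(rng, size, scale=5.0):
    # Same construction as maxwell_dist above: scale * sqrt(ChiSquared(nu=3))
    return scale * np.sqrt(rng.chisquare(3, size=size))


rng = np.random.default_rng(0)
samples = truncated_rejection_sample(maxwell_draw, lower=0.0, upper=2.0, size=(1000,), rng=rng)
assert samples.shape == (1000,) and np.all(samples <= 2.0)
```

With a number of iterations that depends on the random draws, such a loop cannot be a fixed-length operation, which is why it shows up as a while-style loop in a symbolic graph.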
No other SymbolicRandomVariable needs to do this. The others avoid boxing their base RVs by simply resizing them to the total size and using the resized RVs as explicit inputs to the inner graph.
For instance, Mixture resizes the component RVs to the "total size" and then stochastically indexes them based on its internal Categorical RV. ZeroSumNormal creates Normals as inputs and simply subtracts the mean.
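The resize-and-index approach can be sketched in NumPy as follows. This is illustrative only, not PyMC's code, and mixture_draw and zero_sum_normal_draw are hypothetical helpers. All components are drawn at the full output size and passed in as explicit inputs, a categorical draw picks one component per entry, and a zero-sum normal is obtained by subtracting the mean of plain normal draws along the last axis. Nothing here ever re-runs a component, which is why these RVs never had to be boxed.

```python
import numpy as np


def mixture_draw(rng, weights, component_draws):
    """Pick one component per entry according to a categorical draw.

    component_draws has shape (n_components, *size): every component
    was already drawn at the full ("total") size before indexing.
    """
    n_components = component_draws.shape[0]
    size = component_draws.shape[1:]
    # One categorical index per output entry
    idx = rng.choice(n_components, size=size, p=weights)
    # Gather component idx[...] at each position
    return np.take_along_axis(component_draws, idx[None, ...], axis=0)[0]


def zero_sum_normal_draw(rng, sigma, size):
    # Draw plain normals, then subtract the mean along the last axis
    z = rng.normal(0.0, sigma, size=size)
    return z - z.mean(axis=-1, keepdims=True)


rng = np.random.default_rng(1)
size = (500,)
# Both components are drawn at the total size, then indexed
components = np.stack([rng.normal(-10.0, 1.0, size=size), rng.normal(10.0, 1.0, size=size)])
mix = mixture_draw(rng, weights=[0.3, 0.7], component_draws=components)
zsn = zero_sum_normal_draw(rng, sigma=1.0, size=(4, 6))
```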
Such an approach, however, makes it tricky for Truncated to know exactly what constitutes the "true" inputs of an underlying SymbolicRandomVariable, and for this reason it rejected, and still rejects, arbitrary SymbolicRandomVariables. The exceptions are the SymbolicRandomVariables created via CustomDist, because those are already "pre-boxed" in a sense: we know the relevant graph must start at dist.owner.inputs. Now that our class can safely manage and replace shared RNG inputs, we can allow Truncated to handle such RVs, even if they require several shared RNGs.
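For reference, the second assertion in the example at the top follows from the standard truncation identity: for lower <= x <= upper, log p_trunc(x) = log p(x) - log(F(upper) - F(lower)), and the subtracted term is negative whenever the bounds exclude some probability mass. maxwell_dist draws scale * sqrt(ChiSquared(3)), which is a Maxwell distribution, so the check can be reproduced in plain Python with the closed-form Maxwell pdf and CDF. This is a standalone sketch, not how PyMC computes the logp; the helper names are hypothetical.

```python
import math


def maxwell_logpdf(x: float, a: float) -> float:
    # Maxwell density: sqrt(2/pi) * x^2 * exp(-x^2 / (2 a^2)) / a^3, for x > 0
    if x <= 0:
        return -math.inf
    return 0.5 * math.log(2 / math.pi) + 2 * math.log(x) - x**2 / (2 * a**2) - 3 * math.log(a)


def maxwell_cdf(x: float, a: float) -> float:
    # Maxwell CDF: erf(x / (sqrt(2) a)) - sqrt(2/pi) * (x / a) * exp(-x^2 / (2 a^2))
    if x <= 0:
        return 0.0
    return math.erf(x / (math.sqrt(2) * a)) - math.sqrt(2 / math.pi) * (x / a) * math.exp(-x**2 / (2 * a**2))


def truncated_logpdf(x: float, a: float, lower: float, upper: float) -> float:
    if not lower <= x <= upper:
        return -math.inf
    # Renormalize by the probability mass kept between the bounds
    return maxwell_logpdf(x, a) - math.log(maxwell_cdf(upper, a) - maxwell_cdf(lower, a))


scale = 5.0
base = maxwell_logpdf(3.0, scale)
trunc = truncated_logpdf(3.0, scale, lower=0.0, upper=5.0)
assert trunc > base  # same comparison as the pm.logp assertion above
```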
Related to https://github.com/pymc-devs/pymc/discussions/6905#discussioncomment-7002111
TODO
- [x] This is marked as draft because it includes changes from #6923
- [x] Failing due to https://github.com/pymc-devs/pytensor/pull/475
Codecov Report
Attention: Patch coverage is 91.08280% with 14 lines in your changes missing coverage. Please review.
Project coverage is 87.75%. Comparing base (abe7bc9) to head (374e4e3).
Current head 374e4e3 differs from the pull request's most recent head 92efe38. Consider uploading reports for commit 92efe38 to get more accurate results.
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #6947      +/-   ##
==========================================
- Coverage   92.30%   87.75%   -4.55%
==========================================
  Files         100      100
  Lines       16888    16958      +70
==========================================
- Hits        15588    14882     -706
- Misses       1300     2076     +776
Files | Coverage Δ | |
---|---|---|
pymc/distributions/timeseries.py | 94.40% <100.00%> (-0.18%) | ↓ |
pymc/distributions/truncated.py | 99.44% <100.00%> (+0.03%) | ↑ |
pymc/distributions/distribution.py | 95.58% <91.54%> (+1.29%) | ↑ |
pymc/pytensorf.py | 90.71% <66.66%> (-0.58%) | ↓ |
This PR depends on #7227
Looks like some changes are still needed before the tests pass.
See my comment above: it needs the pytensor dependency bump, which is happening in a separate PR.
Oh OK. I know it's not the most essential thing for this PR, but why does Scan require a shared variable as an argument?
On Thu, 28 Mar 2024, 22:35 Ricardo Vieira wrote:
Looks like some changes are still needed before the tests pass.
See my comment above: it needs the pytensor dependency bump, which is happening in a separate PR.
It's a limitation in the original implementation of Scan, where RNG variables must be shared. It's one of the things that we are hoping to solve with https://github.com/pymc-devs/pytensor/pull/191
General comment: what do you think about a warning if the Truncated distribution has to fall back to rejection sampling? This will introduce a while scan into the graph, which could be quite surprising to users (JAX mode no longer possible, potentially a big performance hit).
I wouldn't add a warning, because there is nothing the user can do instead. I can add a note in the docstrings if there's no mention of it yet.