Implement specialized MvNormal density based on precision matrix
Description
This PR explores a specialized logp for MvNormal (and possibly MvStudentT) parametrized directly in terms of `tau`. A common model implementation looks like:
```python
import pymc as pm
import numpy as np

A = np.array([
    [0, 1, 1],
    [1, 0, 1],
    [1, 1, 0],
])
np.testing.assert_allclose(A, A.T, err_msg="should be symmetric")
D = np.diag(A.sum(axis=-1))  # degree matrix

with pm.Model() as m:
    tau = pm.InverseGamma("tau", 1, 1)
    alpha = pm.Beta("alpha", 10, 10)
    Q = tau * (D - alpha * A)
    y = pm.MvNormal("y", mu=np.zeros(3), tau=Q)
```
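For context, the density being specialized is the multivariate normal written directly in terms of the precision matrix, so no matrix inverse (and no intermediate covariance) is needed when `tau` is given. A minimal numpy/scipy sketch of that identity, purely for reference and not part of the PR:

```python
# Not from the PR -- illustrates the identity the specialized logp relies on:
# with Q = Sigma^{-1},
#   logp(x) = 0.5 * logdet(Q) - 0.5 * (x - mu)^T Q (x - mu) - 0.5 * n * log(2*pi)
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(42)
n = 3
L = rng.normal(size=(n, n))
Sigma = L @ L.T + n * np.eye(n)   # SPD covariance
Q = np.linalg.inv(Sigma)          # precision matrix
mu = np.zeros(n)
x = rng.normal(size=n)

logp_tau = (
    0.5 * np.linalg.slogdet(Q)[1]
    - 0.5 * (x - mu) @ Q @ (x - mu)
    - 0.5 * n * np.log(2 * np.pi)
)
np.testing.assert_allclose(logp_tau, multivariate_normal(mu, Sigma).logpdf(x))
```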
TODO (some are optional for this PR)
- [x] Benchmark to confirm it's more performant than the old form, which involved a matrix inverse and a Cholesky decomposition (now we have a slogdet that does LU factorization under the hood; see the sketch after this list)
- [x] Benchmark gradients as well
- [x] Check whether CAR is redundant with this?
- [x] Add tests
- [ ] Consider extending to MvStudentT
- [x] Explore whether something based on Cholesky decomposition still makes (more) sense here. CC @aseyboldt
- [x] ~~Sparse implementation? May need some ideas like: https://stackoverflow.com/questions/19107617/how-to-compute-scipy-sparse-matrix-determinant-without-turning-it-to-dense~~ Investigate in a follow up PR
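As a quick reference for the log-determinant question above: for a symmetric positive definite Q, slogdet (LU-based) and the diagonal of the Cholesky factor give the same result, so either route is numerically viable. An illustrative check with an assumed setup, not code from the PR:

```python
# Illustrative only: logdet(Q) via LU-based slogdet vs. via the Cholesky factor.
# For SPD Q with Q = L L^T, logdet(Q) = 2 * sum(log(diag(L))).
import numpy as np

rng = np.random.default_rng(0)
n = 50
M = rng.normal(size=(n, n))
Q = M @ M.T + n * np.eye(n)  # symmetric positive definite

sign, logdet_lu = np.linalg.slogdet(Q)
chol = np.linalg.cholesky(Q)
logdet_chol = 2 * np.sum(np.log(np.diag(chol)))

assert sign == 1.0
np.testing.assert_allclose(logdet_lu, logdet_chol)
```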
Related Issue
- [ ] Related to https://github.com/pymc-devs/pymc-experimental/issues/340
Checklist
- [x] Checked that the pre-commit linting/style checks pass
- [ ] Included tests that prove the fix is effective or that the new feature works
- [ ] Added necessary documentation (docstrings and/or example notebooks)
- [ ] If you are a pro: each commit corresponds to a relevant logical change
Type of change
- [x] New feature / enhancement
- [ ] Bug fix
- [ ] Documentation
- [ ] Maintenance
- [ ] Other (please specify):
CC @theorashid @elizavetasemenova
Documentation preview: https://pymc--7345.org.readthedocs.build/en/7345/
Implementation checks may fail until https://github.com/pymc-devs/pytensor/issues/799 is fixed
Benchmark code
```python
import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(123)
n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n, n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)

with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)

logp_fn = m.compile_logp().f
logp_fn.trust_input = True
print("logp")
pytensor.dprint(logp_fn)

dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input = True
print("dlogp")
pytensor.dprint(dlogp_fn)

np.testing.assert_allclose(logp_fn(x_test), np.array(-1789.93662205))
np.testing.assert_allclose(np.sum(dlogp_fn(x_test) ** 2), np.array(18445204.8755109), rtol=1e-6)

# Before: 2.66 ms
# After: 1.31 ms
%timeit -n 1000 logp_fn(x_test)

# Before: 2.45 ms
# After: 72 µs
%timeit -n 1000 dlogp_fn(x_test)
```
Codecov Report
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 92.20%. Comparing base (8eaa9be) to head (8550f01). Report is 100 commits behind head on main.
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##             main    #7345      +/-   ##
==========================================
+ Coverage   92.18%   92.20%   +0.01%
==========================================
  Files         103      103
  Lines       17263    17301      +38
==========================================
+ Hits        15914    15952      +38
  Misses       1349     1349
```

| Files with missing lines | Coverage Δ | |
|---|---|---|
| pymc/distributions/multivariate.py | 93.10% <100.00%> (+0.25%) | :arrow_up: |
| pymc/logprob/rewriting.py | 89.18% <100.00%> (+0.17%) | :arrow_up: |
The final question is just whether we want to / can do a similar thing for MvStudentT. Otherwise it's ready to merge on my end.
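For what it's worth, the MvStudentT density can also be written directly in terms of the precision matrix, so the same trick should carry over. A hedged sketch of that form (not in this PR), checked against scipy:

```python
# Illustrative only: MvStudentT logp parametrized by the precision matrix
# Q = Sigma^{-1}, avoiding the inverse/factorization round-trip.
import numpy as np
from scipy.special import gammaln
from scipy.stats import multivariate_t

rng = np.random.default_rng(1)
p = 4
nu = 7.0
L = rng.normal(size=(p, p))
Sigma = L @ L.T + p * np.eye(p)
Q = np.linalg.inv(Sigma)
mu = np.zeros(p)
x = rng.normal(size=p)

quad = (x - mu) @ Q @ (x - mu)
logp_tau = (
    gammaln((nu + p) / 2)
    - gammaln(nu / 2)
    - 0.5 * p * np.log(nu * np.pi)
    + 0.5 * np.linalg.slogdet(Q)[1]      # -0.5 * logdet(Sigma)
    - 0.5 * (nu + p) * np.log1p(quad / nu)
)
np.testing.assert_allclose(logp_tau, multivariate_t(mu, Sigma, df=nu).logpdf(x))
```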
CC @elizavetasemenova
Last benchmarks, running the following script:
```python
%env OMP_NUM_THREADS=1

USE_TAU = True

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(123)
n = 100
L = rng.uniform(low=0.1, high=1.0, size=(n, n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)

with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    if USE_TAU:
        x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)
    else:
        x = pm.MvNormal("x", mu=pt.zeros(n), cov=Q)

logp_fn = m.compile_logp().f
logp_fn.trust_input = True
print("logp")
pytensor.dprint(logp_fn)

dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input = True
print("dlogp")
pytensor.dprint(dlogp_fn)

%timeit -n 10000 logp_fn(x_test)
%timeit -n 10000 dlogp_fn(x_test)
```
USE_TAU = True, without the optimization:
```
logp
Composite{((i2 - (i0 * i1)) - i3)} [id A] 'x_logprob' 9
 ├─ 0.5 [id B]
 ├─ DropDims{axis=0} [id C] 8
 │  └─ CAReduce{Composite{(i0 + sqr(i1))}, axis=1} [id D] 7
 │     └─ Transpose{axes=[1, 0]} [id E] 5
 │        └─ SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2} [id F] 3
 │           ├─ Cholesky{lower=True, destructive=False, on_error='nan'} [id G] 2
 │           │  └─ MatrixInverse [id H] 1
 │           │     └─ Q [id I]
 │           └─ ExpandDims{axis=1} [id J] 0
 │              └─ x [id K]
 ├─ -91.89385332046727 [id L]
 └─ CAReduce{Composite{(i0 + log(i1))}, axes=None} [id M] 6
    └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id N] 4
       └─ Cholesky{lower=True, destructive=False, on_error='nan'} [id G] 2
          └─ ···

dlogp
DropDims{axis=1} [id A] '(dx_logprob/dx)' 7
 └─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2} [id B] 6
    ├─ Transpose{axes=[1, 0]} [id C] 5
    │  └─ Cholesky{lower=True, destructive=False, on_error='nan'} [id D] 2
    │     └─ MatrixInverse [id E] 1
    │        └─ Q [id F]
    └─ Neg [id G] 4
       └─ SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2} [id H] 3
          ├─ Cholesky{lower=True, destructive=False, on_error='nan'} [id D] 2
          │  └─ ···
          └─ ExpandDims{axis=1} [id I] 0
             └─ x [id J]

541 µs ± 56.8 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
503 µs ± 41.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
USE_TAU = True, with the optimization:

```
logp
Composite{(i4 * ((i2 - (i0 * i1)) + i3))} [id A] 'x_logprob' 11
 ├─ 2.0 [id B]
 ├─ CAReduce{Composite{(i0 + log(i1))}, axes=None} [id C] 10
 │  └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id D] 9
 │     └─ Cholesky{lower=True, destructive=False, on_error='raise'} [id E] 8
 │        └─ Q [id F]
 ├─ 183.78770664093454 [id G]
 ├─ DropDims{axis=0} [id H] 7
 │  └─ CGemv{inplace} [id I] 6
 │     ├─ AllocEmpty{dtype='float64'} [id J] 5
 │     │  └─ 1 [id K]
 │     ├─ 1.0 [id L]
 │     ├─ ExpandDims{axis=0} [id M] 4
 │     │  └─ x [id N]
 │     ├─ CGemv{inplace} [id O] 3
 │     │  ├─ AllocEmpty{dtype='float64'} [id P] 2
 │     │  │  └─ Shape_i{1} [id Q] 1
 │     │  │     └─ Q [id F]
 │     │  ├─ 1.0 [id L]
 │     │  ├─ Transpose{axes=[1, 0]} [id R] 'Q.T' 0
 │     │  │  └─ Q [id F]
 │     │  ├─ x [id N]
 │     │  └─ 0.0 [id S]
 │     └─ 0.0 [id S]
 └─ -0.5 [id T]

dlogp
CGemv{inplace} [id A] '(dx_logprob/dx)' 5
 ├─ CGemv{inplace} [id B] 4
 │  ├─ AllocEmpty{dtype='float64'} [id C] 3
 │  │  └─ Shape_i{0} [id D] 2
 │  │     └─ Q [id E]
 │  ├─ 1.0 [id F]
 │  ├─ Q [id E]
 │  ├─ Mul [id G] 1
 │  │  ├─ [-0.5] [id H]
 │  │  └─ x [id I]
 │  └─ 0.0 [id J]
 ├─ -0.5 [id K]
 ├─ Transpose{axes=[1, 0]} [id L] 'Q.T' 0
 │  └─ Q [id E]
 ├─ x [id I]
 └─ 1.0 [id F]

160 µs ± 34.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
15.9 µs ± 2.29 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
For reference, USE_TAU = False before and after (unchanged):

Before:
```
...
260 µs ± 13 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
275 µs ± 19.5 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
After:
```
...
259 µs ± 46.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
275 µs ± 30.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
Summary
- tau used to be ~2x slower than direct cov for both logp and dlogp, due to the extra MatrixInverse.
- tau is now ~1.6x faster for logp and ~17x faster for dlogp than direct cov (see the gradient sketch below).
- Net speedup for the tau parametrization: ~3.4x for logp and ~32x for dlogp.
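The dlogp speedup is visible in the optimized graph above: with the precision parametrization the score reduces to a single matrix-vector product, `d logp / dx = -Q @ (x - mu)`, instead of triangular solves. A small finite-difference check of that identity (illustrative only, not part of the PR):

```python
# Illustrative check: the analytic gradient of the MvNormal logp w.r.t. x,
# written with the precision matrix, matches a finite-difference estimate.
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
n = 5
L = rng.normal(size=(n, n))
Sigma = L @ L.T + n * np.eye(n)
Q = np.linalg.inv(Sigma)
mu = np.zeros(n)
x = rng.normal(size=n)

analytic = -Q @ (x - mu)          # one GEMV, no factorization needed
dist = multivariate_normal(mu, Sigma)
eps = 1e-6
numeric = np.array([
    (dist.logpdf(x + eps * e) - dist.logpdf(x - eps * e)) / (2 * eps)
    for e in np.eye(n)
])
np.testing.assert_allclose(analytic, numeric, rtol=1e-4)
```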
The numpyro tests are failing, probably because it now requires a more recent version of jax. Should be fixed by #7407.
@aseyboldt anything that should block this PR?
Looks good. I think it is possible that we could further improve the MvNormal in both parametrizations, but this is definitely an improvement as it is. Most of all, I think we should do the same for the CholeskyMvNormal. It looks like we are computing the cov just to re-compute the Cholesky again? At some point we did make use of the Cholesky directly, but I guess that got lost in a refactor with the pytensor RVs?
We don't recompute the Cholesky; we have rewrites to remove it, and even a specific test for it: https://github.com/pymc-devs/pymc/blob/b407c01accc65a68d17803b6208d6dfdb0e40877/tests/distributions/test_multivariate.py#L2368
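For anyone who wants to verify this locally, one way to sanity-check a compiled graph is to count the Cholesky Ops in it. An illustrative snippet, assuming a `logp_fn` compiled as in the benchmark scripts above:

```python
# Illustrative only: inspect the compiled function's graph and count Cholesky
# nodes; with the rewrites in place there should be no redundant factorization.
from pytensor.tensor.slinalg import Cholesky

cholesky_nodes = [
    node
    for node in logp_fn.maker.fgraph.apply_nodes
    if isinstance(node.op, Cholesky)
]
print(f"Cholesky nodes in the logp graph: {len(cholesky_nodes)}")
```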