
Implement specialized MvNormal density based on precision matrix

Open ricardoV94 opened this issue 1 year ago • 5 comments

Description

This PR explores a specialized logp for an MvNormal (and possibly MvStudentT) parametrized directly in terms of tau. A common model that motivates this looks like:

import pymc as pm
import numpy as np

# Adjacency matrix of a 3-node graph
A = np.array([
    [0, 1, 1],
    [1, 0, 1],
    [1, 1, 0]
])
# Degree matrix (diagonal), so that Q = tau * (D - alpha * A) is a valid CAR-style precision
D = np.diag(A.sum(axis=-1))
np.testing.assert_allclose(A, A.T, err_msg="A should be symmetric")

with pm.Model() as m:
    tau = pm.InverseGamma("tau", 1, 1)
    alpha = pm.Beta("alpha", 10, 10)
    Q = tau * (D - alpha * A)
    y = pm.MvNormal("y", mu=np.zeros(3), tau=Q)

TODO (some are optional for this PR)

  • [x] Benchmark to confirm it's more performant than the old form, which required a matrix inverse plus a Cholesky decomposition (the new form uses a slogdet, which does an LU factorization under the hood; see the numpy sketch after this list)
  • [x] Benchmark gradients as well
  • [x] Check whether CAR is redundant with this?
  • [x] Add tests
  • [ ] Consider extending to MvStudentT
  • [x] Explore whether something based on Cholesky decomposition still makes (more) sense here. CC @aseyboldt
  • [x] ~~Sparse implementation? May need some ideas like: https://stackoverflow.com/questions/19107617/how-to-compute-scipy-sparse-matrix-determinant-without-turning-it-to-dense~~ Investigate in a follow up PR
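
For context, here is a minimal numpy sketch (my own illustration, not code from this PR) of the tau-parametrized logp the first item refers to: the log-density needs only log|Q| (via slogdet) and a quadratic form, with no matrix inverse or Cholesky of Q:

import numpy as np
from scipy import stats

rng = np.random.default_rng(123)
n = 4
L = rng.uniform(low=0.1, high=1.0, size=(n, n))
Sigma = L @ L.T
Q = np.linalg.inv(Sigma)  # the precision matrix (tau)
x = rng.normal(size=n)

# logp in terms of Q alone: slogdet (LU under the hood) plus a quadratic form
_, logdet_Q = np.linalg.slogdet(Q)
logp = 0.5 * (logdet_Q - n * np.log(2 * np.pi) - x @ Q @ x)

np.testing.assert_allclose(
    logp, stats.multivariate_normal(mean=np.zeros(n), cov=Sigma).logpdf(x)
)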

Related Issue

  • [ ] Related to https://github.com/pymc-devs/pymc-experimental/issues/340

Checklist

Type of change

  • [x] New feature / enhancement
  • [ ] Bug fix
  • [ ] Documentation
  • [ ] Maintenance
  • [ ] Other (please specify):

CC @theorashid @elizavetasemenova


📚 Documentation preview 📚: https://pymc--7345.org.readthedocs.build/en/7345/

ricardoV94 avatar Jun 03 '24 11:06 ricardoV94

Implementation checks may fail until https://github.com/pymc-devs/pytensor/issues/799 is fixed

ricardoV94 avatar Jun 03 '24 11:06 ricardoV94


Benchmark code

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(123)

n = 100
# Build a random SPD covariance and invert it to get the precision matrix
L = rng.uniform(low=0.1, high=1.0, size=(n, n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)

with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)

logp_fn = m.compile_logp().f
logp_fn.trust_input = True  # skip input validation so the benchmark measures only the graph
print("logp")
pytensor.dprint(logp_fn)

dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input = True
print("dlogp")
pytensor.dprint(dlogp_fn)

np.testing.assert_allclose(logp_fn(x_test), np.array(-1789.93662205))

np.testing.assert_allclose(np.sum(dlogp_fn(x_test) ** 2), np.array(18445204.8755109), rtol=1e-6)

# Before: 2.66 ms
# After: 1.31 ms
%timeit -n 1000 logp_fn(x_test)

# Before: 2.45 ms
# After: 72 µs
%timeit -n 1000 dlogp_fn(x_test)

ricardoV94 avatar Jun 21 '24 15:06 ricardoV94

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 92.20%. Comparing base (8eaa9be) to head (8550f01). Report is 100 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7345      +/-   ##
==========================================
+ Coverage   92.18%   92.20%   +0.01%     
==========================================
  Files         103      103              
  Lines       17263    17301      +38     
==========================================
+ Hits        15914    15952      +38     
  Misses       1349     1349              
Files with missing lines            Coverage Δ
pymc/distributions/multivariate.py  93.10% <100.00%> (+0.25%) :arrow_up:
pymc/logprob/rewriting.py           89.18% <100.00%> (+0.17%) :arrow_up:

codecov[bot] avatar Jun 25 '24 13:06 codecov[bot]

The final question is just whether we want to / can do a similar thing for the MvStudentT. Otherwise it's ready to merge on my end.
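
For reference, a minimal numpy/scipy sketch (my own, not part of the PR) showing that the same precision trick applies to the MvStudentT: its log-density also needs only log|Q| and a quadratic form, so no inverse is required:

import numpy as np
from scipy import stats
from scipy.special import gammaln

def mvt_logp_tau(x, nu, mu, Q):
    # MvStudentT logp using only the precision matrix: slogdet + quadratic form
    n = len(mu)
    _, logdet_Q = np.linalg.slogdet(Q)
    quad = (x - mu) @ Q @ (x - mu)
    return (
        gammaln((nu + n) / 2)
        - gammaln(nu / 2)
        - 0.5 * n * np.log(nu * np.pi)
        + 0.5 * logdet_Q
        - 0.5 * (nu + n) * np.log1p(quad / nu)
    )

rng = np.random.default_rng(123)
n = 4
A = rng.normal(size=(n, n))
Sigma = A @ A.T + n * np.eye(n)  # well-conditioned SPD matrix
Q = np.linalg.inv(Sigma)
x = rng.normal(size=n)

ref = stats.multivariate_t(loc=np.zeros(n), shape=Sigma, df=5).logpdf(x)
np.testing.assert_allclose(mvt_logp_tau(x, 5, np.zeros(n), Q), ref)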

CC @elizavetasemenova

ricardoV94 avatar Jun 25 '24 13:06 ricardoV94

Final benchmarks, from running the following script:

%env OMP_NUM_THREADS=1

USE_TAU = True

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = np.random.default_rng(123)

n = 100
# Same SPD precision setup as the earlier benchmark
L = rng.uniform(low=0.1, high=1.0, size=(n, n))
Sigma = L @ L.T
Q_test = np.linalg.inv(Sigma)
x_test = rng.normal(size=n)

with pm.Model(check_bounds=False) as m:
    Q = pm.Data("Q", Q_test)
    if USE_TAU:
        x = pm.MvNormal("x", mu=pt.zeros(n), tau=Q)
    else:
        x = pm.MvNormal("x", mu=pt.zeros(n), cov=Q)

logp_fn = m.compile_logp().f
logp_fn.trust_input = True
print("logp")
pytensor.dprint(logp_fn)

dlogp_fn = m.compile_dlogp().f
dlogp_fn.trust_input = True
print("dlogp")
pytensor.dprint(dlogp_fn)

%timeit -n 10000 logp_fn(x_test)
%timeit -n 10000 dlogp_fn(x_test)

USE_TAU = True, without the optimization:

logp
Composite{((i2 - (i0 * i1)) - i3)} [id A] 'x_logprob' 9
 ├─ 0.5 [id B]
 ├─ DropDims{axis=0} [id C] 8
 │  └─ CAReduce{Composite{(i0 + sqr(i1))}, axis=1} [id D] 7
 │     └─ Transpose{axes=[1, 0]} [id E] 5
 │        └─ SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2} [id F] 3
 │           ├─ Cholesky{lower=True, destructive=False, on_error='nan'} [id G] 2
 │           │  └─ MatrixInverse [id H] 1
 │           │     └─ Q [id I]
 │           └─ ExpandDims{axis=1} [id J] 0
 │              └─ x [id K]
 ├─ -91.89385332046727 [id L]
 └─ CAReduce{Composite{(i0 + log(i1))}, axes=None} [id M] 6
    └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id N] 4
       └─ Cholesky{lower=True, destructive=False, on_error='nan'} [id G] 2
          └─ ···

dlogp
DropDims{axis=1} [id A] '(dx_logprob/dx)' 7
 └─ SolveTriangular{trans=0, unit_diagonal=False, lower=False, check_finite=True, b_ndim=2} [id B] 6
    ├─ Transpose{axes=[1, 0]} [id C] 5
    │  └─ Cholesky{lower=True, destructive=False, on_error='nan'} [id D] 2
    │     └─ MatrixInverse [id E] 1
    │        └─ Q [id F]
    └─ Neg [id G] 4
       └─ SolveTriangular{trans=0, unit_diagonal=False, lower=True, check_finite=True, b_ndim=2} [id H] 3
          ├─ Cholesky{lower=True, destructive=False, on_error='nan'} [id D] 2
          │  └─ ···
          └─ ExpandDims{axis=1} [id I] 0
             └─ x [id J]

541 µs ± 56.8 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
503 µs ± 41.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

USE_TAU = True, with the optimization:

logp
Composite{(i4 * ((i2 - (i0 * i1)) + i3))} [id A] 'x_logprob' 11
 ├─ 2.0 [id B]
 ├─ CAReduce{Composite{(i0 + log(i1))}, axes=None} [id C] 10
 │  └─ ExtractDiag{offset=0, axis1=0, axis2=1, view=False} [id D] 9
 │     └─ Cholesky{lower=True, destructive=False, on_error='raise'} [id E] 8
 │        └─ Q [id F]
 ├─ 183.78770664093454 [id G]
 ├─ DropDims{axis=0} [id H] 7
 │  └─ CGemv{inplace} [id I] 6
 │     ├─ AllocEmpty{dtype='float64'} [id J] 5
 │     │  └─ 1 [id K]
 │     ├─ 1.0 [id L]
 │     ├─ ExpandDims{axis=0} [id M] 4
 │     │  └─ x [id N]
 │     ├─ CGemv{inplace} [id O] 3
 │     │  ├─ AllocEmpty{dtype='float64'} [id P] 2
 │     │  │  └─ Shape_i{1} [id Q] 1
 │     │  │     └─ Q [id F]
 │     │  ├─ 1.0 [id L]
 │     │  ├─ Transpose{axes=[1, 0]} [id R] 'Q.T' 0
 │     │  │  └─ Q [id F]
 │     │  ├─ x [id N]
 │     │  └─ 0.0 [id S]
 │     └─ 0.0 [id S]
 └─ -0.5 [id T]

dlogp
CGemv{inplace} [id A] '(dx_logprob/dx)' 5
 ├─ CGemv{inplace} [id B] 4
 │  ├─ AllocEmpty{dtype='float64'} [id C] 3
 │  │  └─ Shape_i{0} [id D] 2
 │  │     └─ Q [id E]
 │  ├─ 1.0 [id F]
 │  ├─ Q [id E]
 │  ├─ Mul [id G] 1
 │  │  ├─ [-0.5] [id H]
 │  │  └─ x [id I]
 │  └─ 0.0 [id J]
 ├─ -0.5 [id K]
 ├─ Transpose{axes=[1, 0]} [id L] 'Q.T' 0
 │  └─ Q [id E]
 ├─ x [id I]
 └─ 1.0 [id F]

160 µs ± 34.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
15.9 µs ± 2.29 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

For reference: USE_TAU = False before and after (unchanged)

Before:

...
260 µs ± 13 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
275 µs ± 19.5 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

After:

...
259 µs ± 46.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
275 µs ± 30.6 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Summary

  • tau used to be ~2x slower logp and ~2x slower dlogp vs direct cov, due to the extra MatrixInverse
  • tau is now ~2x faster logp and ~20x faster dlogp vs direct cov
  • tau total speedup: ~4x faster logp and ~40x faster dlogp
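
The dlogp speedup is visible in the optimized graph above: with mu = 0 the gradient reduces to -0.5 * (Q + Q.T) @ x, i.e. two matrix-vector products and no factorization at all. A minimal numpy sketch (my own illustration, not code from the PR) checking that formula against finite differences:

import numpy as np

rng = np.random.default_rng(123)
n = 5
L = rng.uniform(low=0.1, high=1.0, size=(n, n))
Sigma = L @ L.T + np.eye(n)  # keep it well-conditioned
Q = np.linalg.inv(Sigma)     # symmetric precision matrix
x = rng.normal(size=n)

def logp(x):
    # MvNormal logp in terms of the precision matrix, mu = 0
    _, logdet_Q = np.linalg.slogdet(Q)
    return 0.5 * (logdet_Q - n * np.log(2 * np.pi) - x @ Q @ x)

grad = -0.5 * (Q + Q.T) @ x  # what the optimized dlogp graph computes

# compare against central finite differences
eps = 1e-6
fd = np.array(
    [(logp(x + eps * e) - logp(x - eps * e)) / (2 * eps) for e in np.eye(n)]
)
np.testing.assert_allclose(grad, fd, rtol=1e-5)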

ricardoV94 avatar Jul 12 '24 14:07 ricardoV94

The numpyro tests are failing, probably because numpyro now requires a more recent version of jax. Should be fixed by #7407

ricardoV94 avatar Jul 12 '24 16:07 ricardoV94

@aseyboldt anything that should block this PR?

ricardoV94 avatar Jul 31 '24 14:07 ricardoV94

Looks good. I think it is possible that we could further improve the MvNormal in both parametrizations, but this is definitely an improvement as it is. Most of all, I think we should do the same for the CholeskyMvNormal. It looks like we are computing the cov just to re-compute the Cholesky again? At some point we did make use of the Cholesky directly, but I guess that got lost in a refactor with the pytensor RVs?

aseyboldt avatar Aug 02 '24 21:08 aseyboldt

> It looks like we are computing the cov just to re-compute the Cholesky again? At some point we did make use of the Cholesky directly, but I guess that got lost in a refactor with the pytensor RVs?

We don't recompute the Cholesky; we have rewrites to remove it, and even a specific test for it: https://github.com/pymc-devs/pymc/blob/b407c01accc65a68d17803b6208d6dfdb0e40877/tests/distributions/test_multivariate.py#L2368
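
For anyone who wants to check this locally, here is a minimal sketch (my own, not the linked test) that counts Cholesky Ops in the compiled logp graph when cov is built as chol @ chol.T; assuming the rewrite fires, the user's explicit Cholesky should be the only one left:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

rng = np.random.default_rng(0)
A_np = rng.normal(size=(3, 3))
Sigma_np = A_np @ A_np.T + 3 * np.eye(3)  # an SPD covariance

with pm.Model() as m:
    cov = pm.Data("cov", Sigma_np)
    chol = pt.linalg.cholesky(cov)
    # cov is rebuilt as chol @ chol.T; the rewrite should reuse chol
    # instead of decomposing the product again
    x = pm.MvNormal("x", mu=np.zeros(3), cov=chol @ chol.T)

logp_fn = m.compile_logp().f
n_chol = sum("Cholesky" in str(node.op) for node in logp_fn.maker.fgraph.apply_nodes)
print(n_chol)  # expected: 1 (the explicit one), not 2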

ricardoV94 avatar Aug 03 '24 02:08 ricardoV94