BUG: PrecisionMvNormalRV missing factor of 1/2 for log determinant in logp function
Describe the issue:
At this line,
logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet
should be
logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form + logdet)
Reproducible code example:
N/A
Error message:
PyMC version information:
5.21.1
Context for the issue:
No response
@ckrapu wanna open a PR by any chance?
This issue seems inactive so I'm happy to take it over if available
Sorry, @ricardoV94, am I missing something, or should this be a won't-fix?
Consider $X \sim \mathrm{MVN}(\mu, T)$, where $T$ is the precision matrix and $k$ is the dimension.
We know: $\log p = -\frac{1}{2}\left(k \log(2\pi) + (x - \mu)^\top T (x - \mu)\right) + \frac{1}{2}\log\det T$
Let $L$ be the lower Cholesky factor of $T$, so that $T = L L^\top$ and $\det T = (\det L)^2$. Hence $\log\det L = \frac{1}{2}\log\det T$.
Thus, $\log p = -\frac{1}{2}\left(k \log(2\pi) + (x - \mu)^\top T (x - \mu)\right) + \log\det L$
This matches the code:
logdet, posdef = _logdet_from_cholesky(nan_lower_cholesky(tau))
logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet
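The identity used above can be checked numerically. This is only a minimal NumPy sketch, and it assumes that the `logdet` in the snippet is the log-determinant of the Cholesky factor (the sum of the logs of its diagonal), as the derivation suggests. The names `rng`, `A`, `logdet_L` and `logdet_tau` are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5
A = rng.standard_normal((n, n))
tau = A @ A.T + 0.01 * np.eye(n)  # symmetric positive definite precision matrix

# Lower-triangular Cholesky factor, tau = L @ L.T
L = np.linalg.cholesky(tau)

# The determinant of a triangular matrix is the product of its diagonal,
# so log det L is the sum of the logs of the diagonal entries.
logdet_L = np.sum(np.log(np.diag(L)))

# Reference value: log det of the full precision matrix
sign, logdet_tau = np.linalg.slogdet(tau)

assert sign == 1.0
np.testing.assert_allclose(logdet_L, 0.5 * logdet_tau, rtol=1e-8, atol=1e-10)
```

If that assumption holds, the factor of 1/2 is already absorbed into `logdet`, which would explain why it sits outside the `-0.5 * (...)` term.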
There is also this test which checks for the logp and fails when the suggested change is made
I also wrote a small test to check code paths. After making the change here to
logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form + logdet)
the following test
import numpy as np
import pymc as pm
import scipy.stats as st
from pymc import Model

def test_logp_random_tau_matches_cov():
    n = 5
    A = np.random.randn(n, n)
    tau = np.dot(A, A.T) + np.eye(n) * 0.01  # ensure it's well-conditioned
    cov = np.linalg.inv(tau)
    mu = np.random.randn(n)
    vals = np.random.randn(n)
    # logp via the covariance parametrization
    logp_cov = pm.logp(pm.MvNormal.dist(mu=mu, cov=cov), vals).eval()
    # logp_tau = pm.logp(pm.MvNormal.dist(mu=mu, tau=tau), vals).eval()
    # logp via the precision parametrization, with tau passed in as a model variable
    with Model() as m:
        Q = pm.Flat("Q", shape=(n, n))
        y = pm.MvNormal("y", mu=mu, tau=Q)
    y_logp_fn = m.compile_logp(vars=[y]).f
    logp_tau = y_logp_fn(y=vals, Q=tau)
    np.testing.assert_allclose(logp_cov, logp_tau)  # <<< assert 1 >>>
    np.testing.assert_allclose(logp_tau, st.multivariate_normal.logpdf(vals, mu, cov))  # <<< assert 2 >>>
    np.testing.assert_allclose(logp_cov, st.multivariate_normal.logpdf(vals, mu, cov))  # <<< assert 3 >>>
fails at asserts 1 and 2 but passes assert 3. This makes me think we shouldn't make this change, as the present formula is correct.
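The same comparison can be made without PyMC. This is a rough sketch in plain NumPy/SciPy that rewrites the two expressions from this issue by hand. It assumes `logdet` is the log-determinant of the Cholesky factor of `tau`, and the names `logp_current`, `logp_proposed` and `reference` are illustrative only.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(1)
n = 5
A = rng.standard_normal((n, n))
tau = A @ A.T + 0.01 * np.eye(n)  # symmetric positive definite precision matrix
mu = rng.standard_normal(n)
x = rng.standard_normal(n)

k = n
delta = x - mu
quadratic_form = delta @ tau @ delta
# Assumed meaning of `logdet`: log det of the lower Cholesky factor of tau
logdet = np.sum(np.log(np.diag(np.linalg.cholesky(tau))))

# Expression currently in the code, and the one proposed in this issue
logp_current = -0.5 * (k * np.log(2 * np.pi) + quadratic_form) + logdet
logp_proposed = -0.5 * (k * np.log(2 * np.pi) + quadratic_form + logdet)

# Reference log-density from SciPy, using the covariance parametrization
cov = np.linalg.inv(tau)
cov = 0.5 * (cov + cov.T)  # remove tiny asymmetries from the inversion
reference = stats.multivariate_normal.logpdf(x, mean=mu, cov=cov)

# The current expression agrees with SciPy
np.testing.assert_allclose(logp_current, reference, rtol=1e-6)
# The proposed expression differs from it by 1.5 * logdet
np.testing.assert_allclose(logp_current - logp_proposed, 1.5 * logdet,
                           rtol=1e-8, atol=1e-10)
```

Under that assumption, the proposed expression both flips the sign of the log-determinant term and halves it, which accounts for the `1.5 * logdet` gap.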
Happy to be corrected. Please let me know 🙏
We tested against the regular logpdf so you're probably right that our implementation is correct @asifzubair. I didn't try to follow the math, it's late here :).
CC'ing @ckrapu just in case
Great, thanks, @ricardoV94. I'll stop looking at this for now, but will come back to it if there is activity on the thread. Thanks again! 🙏
Thanks for checking! I'll tentatively close this