[Bug] Unexpected results in `inv_quad_logdet()` after `.add_low_rank()` with dense inputs

Open cwindolf opened this issue 1 year ago • 3 comments

First of all thanks for this incredibly great library, it has been a lifesaver! (and hi @gpleiss from jlg :)

🐛 Bug

When adding A (a DenseLinearOperator) to v (a LowRankRootLinearOperator), the result of inv_quad_logdet() is not what it should be. (Other cases work -- for instance, if A is a DiagLinearOperator, things are fine.)

Below, I've included a test case based on computing normal likelihoods so that I can produce some SciPy numbers for ground truth.

To reproduce

I've broken up the code into a few cases, showing ways to get the right answer and the add_low_rank case which breaks.

Imports...

import torch
import numpy as np
from scipy.stats import multivariate_normal
import linear_operator
from linear_operator import operators

torch.__version__, linear_operator.__version__, np.__version__
# => ('2.5.1.post3', '0.5.3', '1.26.4')

Simulate some test data, and get SciPy log liks:

N = 2 ** 12
D = 8
mean = np.zeros(D)

# test case: normal log likelihoods
# using scipy for reference point
rg = np.random.default_rng(0)

# make a low rank + identity covariance
v = rg.normal(size=(D, 2))
cov = np.eye(D) + v @ v.T

# draws...
y = rg.multivariate_normal(mean=mean, cov=cov, size=N)

# log liks
rv = multivariate_normal(mean=mean, cov=cov)
scipy_lls = rv.logpdf(y)

# convert to torch for below
mean = torch.asarray(mean, dtype=torch.float)
v = torch.asarray(v, dtype=torch.float)
cov = torch.asarray(cov, dtype=torch.float)
y = torch.asarray(y, dtype=torch.float)

# helper function for getting log liks via inv_quad_logdet
log2pi = torch.log(torch.tensor(2 * np.pi))
def ll_via_inv_quad(cov, y):
    inv_quad, logdet = linear_operator.inv_quad_logdet(cov, y.T, logdet=True, reduce_inv_quad=False)
    ll = -0.5 * (inv_quad + logdet + log2pi * y.shape[1])
    return ll
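
For reference, the helper is just evaluating the multivariate normal log density

$\log p(y) = -\tfrac{1}{2}\left(y^\top \Sigma^{-1} y + \log\det\Sigma + D\log 2\pi\right),$

where the quadratic form and $\log\det\Sigma$ come from the single inv_quad_logdet() call, and $D$ is y.shape[1].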

Cases which behave as expected

The isclose checks here are True.

# logliks via dense operator
dense_cov = operators.DenseLinearOperator(cov)
dense_lls = ll_via_inv_quad(dense_cov, y)
np.isclose(scipy_lls, dense_lls).all()  # => True

# logliks via diag low rank
diag_eye = operators.DiagLinearOperator(torch.ones(D))
root = operators.LowRankRootLinearOperator(v)  # low-rank root, reused below
diag_root_cov = operators.LowRankRootAddedDiagLinearOperator(diag_eye, root)
diag_root_lls = ll_via_inv_quad(diag_root_cov, y)
np.isclose(scipy_lls, diag_root_lls).all()  # => True

Failing case

If we use a dense linear operator and land in .add_low_rank(), the isclose() is False here:

# logliks via dense add_low_rank
dense_eye = operators.DenseLinearOperator(torch.eye(D))
root = operators.LowRankRootLinearOperator(v)
dense_root_cov = dense_eye + root
dense_root_lls = ll_via_inv_quad(dense_root_cov, y)
np.isclose(scipy_lls, dense_root_lls).all()  # => False

The differences are substantial -- the max difference was 876303.6523659548 in this case, and the median absolute difference from scipy was 25754.86743231282.

Workaround

I wanted a way to use Woodbury with a dense operator, so I wrote a quick implementation of a LowRankRootSumLinearOperator which is basically identical to LowRankRootAddedDiagLinearOperator -- it makes a Cholesky of the capacitance matrix. My code is here in case it is helpful at all: https://github.com/cwindolf/dartsort/blob/main/src/dartsort/util/more_operators.py

# log liks via alternative to the dense add_low_rank 
from dartsort.util import more_operators
alt_root_cov = more_operators.LowRankRootSumLinearOperator(dense_eye, root)
alt_root_lls = ll_via_inv_quad(alt_root_cov, y)
np.isclose(scipy_lls, alt_root_lls).all()  # => True
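
For posterity, here's the essence of the workaround in plain torch -- a condensed sketch of the Woodbury-plus-capacitance-Cholesky idea, not the exact code in the linked file (the function name and signature are illustrative):

import torch

def woodbury_solve_logdet(A, U, B):
    # Solve (A + U U^T) X = B and compute logdet(A + U U^T),
    # touching the dense A only through solves.
    # A: (D, D) PSD, U: (D, r) low-rank root, B: (D, K) right-hand sides.
    Ainv_U = torch.linalg.solve(A, U)  # A^{-1} U
    Ainv_B = torch.linalg.solve(A, B)  # A^{-1} B
    # capacitance matrix C = I_r + U^T A^{-1} U, factored once by Cholesky
    C = torch.eye(U.shape[1], dtype=A.dtype) + U.T @ Ainv_U
    Lc = torch.linalg.cholesky(C)
    # Woodbury: (A + U U^T)^{-1} B = A^{-1} B - A^{-1} U C^{-1} U^T A^{-1} B
    X = Ainv_B - Ainv_U @ torch.cholesky_solve(U.T @ Ainv_B, Lc)
    # matrix determinant lemma: logdet(A + U U^T) = logdet(A) + logdet(C)
    logdet = torch.linalg.slogdet(A).logabsdet + 2 * Lc.diagonal().log().sum()
    return X, logdet

With A = torch.eye(D), U = v, B = y.T, the per-column inv_quad values are then (B * X).sum(0), matching reduce_inv_quad=False.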

Expected Behavior

inv_quad_logdet() should return values that match the scipy reference in all of the cases above.

System information

  • linear_operator version: 0.5.3

  • PyTorch version: 2.5.1.post3

  • OS: macOS

cwindolf commented on Dec 01 '24

For some more info: some operations on dense_root_cov work correctly and some don't.

With all vars defined as above, we have:

Things that work:

.solve():

dense_root_solve = dense_root_cov.solve(y.T).T
full_solve = torch.linalg.solve(cov, y.T).T
np.isclose(full_solve, dense_root_solve, atol=1e-5).all(), (full_solve - dense_root_solve).abs().median(), (full_solve - dense_root_solve).abs().max()
# => (True, tensor(1.1921e-07), tensor(1.9073e-06))

linear_operator.inv_quad():

np.isclose(
    linear_operator.inv_quad(cov, y.T, reduce_inv_quad=False),
    linear_operator.inv_quad(dense_root_cov, y.T, reduce_inv_quad=False),
).all()
# => True

Things that don't

.logdet()

dense_root_cov.logdet(), torch.linalg.det(cov).log()
# => (tensor(-13.8775), tensor(3.5932))
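
(As a sanity check on the dense value: by the matrix determinant lemma with $A = I$, $\log\det(I_D + vv^\top) = \log\det(I_2 + v^\top v)$, so

torch.linalg.slogdet(torch.eye(2) + v.T @ v).logabsdet

agrees with torch.linalg.det(cov).log() up to float32 rounding.)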

The inv_quad part of inv_quad_logdet()

torch.abs(
    linear_operator.inv_quad_logdet(cov, y.T, reduce_inv_quad=False, logdet=True)[0]
    - linear_operator.inv_quad_logdet(dense_root_cov, y.T, reduce_inv_quad=False, logdet=True)[0]
).median()
# => tensor(51524.3984)

So, although the standalone inv_quad() works, it does not work as part of inv_quad_logdet(). (The result is the same with logdet=False.)

cwindolf commented on Dec 03 '24

Adding a low rank matrix to a dense, diagonally dominant matrix was unfortunately not the intended use case for add_low_rank. Currently, the default implementation of add_low_rank attempts to update a matrix with an existing low rank decomposition, which is not going to be great when our original matrix is a diagonal matrix.
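
A quick illustration of the obstruction -- just rank counting, not a claim about the add_low_rank internals:

import torch

# I_D + v v^T has D - 2 eigenvalues equal to 1 plus two larger ones,
# so it is full rank and no root with fewer than D columns can represent it
D = 8
v = torch.randn(D, 2, dtype=torch.double)
target = torch.eye(D, dtype=torch.double) + v @ v.T
print(torch.linalg.eigvalsh(target))  # six 1.0's, then two values > 1

# the best rank-2 symmetric approximation keeps only the top eigenpairs...
evals, evecs = torch.linalg.eigh(target)
L = evecs[:, -2:] * evals[-2:].sqrt()
# ...and L @ L.T is singular, so its log determinant is -inf
print(torch.linalg.slogdet(L @ L.T).logabsdet)

Any root-based representation of the full-rank sum therefore has to carry (close to) D columns to get the determinant right.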

Can you describe your intended use case in a bit more detail?

cc @saitcakmak

gpleiss commented on Dec 12 '24

Thanks for your message Geoff! That makes sense.

Yes, my use case is in Gaussian mixture modeling where the $k$th component is $N(\mu_k, C + U_k U_k^T)$. In other words, each component's covariance matrix has a known, dense (or maybe structured somehow, but in any case not low rank), fixed background component $C$ along with a per-component low-rank part $U_k U_k^T$.
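
(For reference, the two identities that make this cheap are the Woodbury formula,

$(C + U_k U_k^\top)^{-1} = C^{-1} - C^{-1} U_k (I + U_k^\top C^{-1} U_k)^{-1} U_k^\top C^{-1},$

and the matrix determinant lemma,

$\log\det(C + U_k U_k^\top) = \log\det C + \log\det(I + U_k^\top C^{-1} U_k),$

so each component only needs solves against the shared $C$ plus a small $r \times r$ factorization.)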

What I'm doing now is to use a workaround class which just adapts LowRankRootAddedDiagLinearOperator to the case where the "other" operator is not assumed diagonal (all it needs from that operator is its .solve(), called a couple of times). Maybe I'm being silly and should just store the inverse of $C$...

cwindolf commented on Dec 12 '24