🐛 Large negative eigenvalues in covariance matrices with CUDA and torch==1.9.1 in batch mode
The symptom: posterior predictive covariance matrices have very large negative eigenvalues, well beyond what can be cured with jitter.
The conditions: exact GP regression in batch mode on CUDA with torch==1.9.1. Dropping back to torch==1.9.0 solves the problem, as does moving to CPU. The problem does not occur with a single GP rather than a batch.
To reproduce
import gpytorch, torch

print(gpytorch.__version__)
print(torch.__version__)

# Set cuda=True and batch=True to get the problem.
# Turning either off makes it disappear.
cuda = True
batch = True
# These other flags don't affect the outcome but are there as sanity checks.
index_0_batch = False
fixed_noise = True
learn_additional_noise = False

device = "cuda" if cuda else "cpu"
bs = torch.Size([11, ]) if batch else torch.Size([])

train_x = torch.randn(100, device=device, dtype=float)
train_y = torch.randn(100, device=device, dtype=float)
train_y_std = torch.randn(100, device=device, dtype=float)**2 * .1
test_x = torch.randn(200, device=device, dtype=float)
test_y = torch.randn(200, device=device, dtype=float)


class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=bs)
        self.covar_module = gpytorch.kernels.RBFKernel(batch_shape=bs)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


if fixed_noise:
    likelihood = gpytorch.likelihoods.FixedNoiseGaussianLikelihood(
        noise=train_y_std, batch_shape=bs,
        learn_additional_noise=learn_additional_noise).to(device)
else:
    likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=bs).to(device)

model = ExactGPModel(train_x, train_y, likelihood).to(device)

model.train()
likelihood.train()

# Includes GaussianLikelihood parameters
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(1):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    if index_0_batch:
        loss = loss[0]
    else:
        loss = loss.mean(dim=0)
    loss.backward()
    optimizer.step()

model.eval()
likelihood.eval()

f_preds = model(test_x)
y_preds = likelihood(model(test_x))
f_covar = f_preds.covariance_matrix

print(f"cuda: {cuda}")
print(f"batch mode: {batch}")
print(f"fixed y noise: {fixed_noise}")
if fixed_noise:
    print(f"learn additional y noise: {learn_additional_noise}")
print("first batch index of loss" if index_0_batch else "mean loss over batch")
print(f"Min covariance eigenvalue: {torch.linalg.eigvalsh(f_covar).min()}")
1.5.1
1.9.1+cu102
cuda: True
batch mode: True
fixed y noise: True
learn additional y noise: False
mean loss over batch
Min covariance eigenvalue: -173010.1313435591
Expected Behavior
The expected behaviour is just what one gets with torch==1.9.0. That is, any numerical differences in the covariance matrices between CPU and CUDA should be minor rather than the major difference seen here.
System information
- gpytorch==1.5.1
- torch=={1.9.0, 1.9.1}
- Google Colab notebook.
Additional context
I found this while experimenting with Bayesian hyperparameter inference using variational inference. The predictive posterior and MLL require computing intractable integrals against the variational posterior. I'm using MC integration, so each of the batch GPs corresponds to i.i.d. samples from the hyperparameter variational posterior.
The code I've supplied above isn't actually doing this, but I think I've stripped it back to the bare minimum required to produce the numerical problem.
I'm not very familiar with the pytorch code base, but the only differences between 1.9.0 and 1.9.1 that look at all related to linear algebra are some shape checks prior to matrix multiplication. (https://github.com/pytorch/pytorch/compare/v1.9.0...v1.9.1)
Interesting. Large negative eval indeed. Is the CUDA version the same for both torch versions?
Maybe @mruberry or @ngimel have any thoughts what could be causing this discrepancy?
A difference between 1.9 and 1.9.1 seems very odd to me as only a few critical fixes were cherry-picked for 1.9.1. Or is the comparison between 1.8 and 1.9.1?
With 1.9.0 (as confirmed by the printed version):
1.5.1
1.9.0+cu102
cuda: True
batch mode: True
fixed y noise: True
learn additional y noise: False
mean loss over batch
Min covariance eigenvalue: -8.396184222906664e-15
And then, after running pip install torch==1.9.1 and restarting the kernel:
1.5.1
1.9.1+cu102
cuda: True
batch mode: True
fixed y noise: True
learn additional y noise: False
mean loss over batch
Min covariance eigenvalue: -1417886.6414530107
So unless I'm missing something, torch 1.9.0 <-> 1.9.1 is the only change?
Do you know whether the eigenvalue computations themselves are ok? What I would suggest is to compute f_covar in both 1.9.0 and 1.9.1 and save the tensors to disk, and then check what torch.linalg.eigvalsh returns for them, again in both torch versions, so that you have 2*2 = 4 combinations. Hopefully that will provide some insight into whether the eigenvalue computation itself has issues, or whether the actual computation of f_covar is different.
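Something along these lines (an untested sketch; the file-naming scheme is just illustrative), run once under each installed torch version:

import torch

# Step 1 (run under each torch version): save the posterior covariance
# computed by the reproducer above.
# torch.save(f_covar.detach().cpu(), f"f_covar_{torch.__version__}.pt")

# Step 2 (also run under each torch version): load both saved tensors and
# check their eigenvalues, giving the 2*2 grid of combinations.
for file_version in ("1.9.0+cu102", "1.9.1+cu102"):
    f_covar = torch.load(f"f_covar_{file_version}.pt")
    print(f"File version: {file_version}")
    print(f"torch version: {torch.__version__}")
    print(f"Min covariance eigenvalue: {torch.linalg.eigvalsh(f_covar).min()}")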
Good idea. Looks like it's not eigvalsh, but definitely the computation of f_covar. Here "file version" refers to the version of torch used to compute f_covar and "torch version" is the version used to read in the saved tensor and compute the eigenvalues.
File version: 1.9.0+cu102
torch version: 1.9.0+cu102
Min covariance eigenvalue: -8.077261033961314e-15
File version: 1.9.1+cu102
torch version: 1.9.0+cu102
Min covariance eigenvalue: -511456.5076018272
File version: 1.9.0+cu102
torch version: 1.9.1+cu102
Min covariance eigenvalue: -8.077261033961314e-15
File version: 1.9.1+cu102
torch version: 1.9.1+cu102
Min covariance eigenvalue: -511456.5076018272
What do the state dicts of the model look like when fitted with either torch version? Seems like some of the parameters must be degenerate somehow to result in this.
They look no different. In fact, I can reproduce all of the above even after deleting the training loop (so all raw parameters are 0 in both versions).
I've just had a poke around at the kernel tensors, looking at their eigenvalues and at the size of their entries and their inverses' entries. I've calculated the kernel matrices K_train_train, K_test_test and K_train_test directly. Then, computing the posterior covariance matrix explicitly, I can recreate the above difference between 1.9.0 and 1.9.1. Staring at the matrices, though, nothing has jumped out yet as being an important difference.
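For reference, the explicit posterior covariance computation is roughly the following (a sketch; the helper name, shapes, and the use of torch.linalg.inv are illustrative rather than exactly what I ran):

import torch

def posterior_covariance(K_train_train, K_train_test, K_test_test, noise):
    # Exact GP posterior covariance, batched over any leading dimensions:
    #   K_** - K_*t (K_tt + diag(noise))^{-1} K_t*
    # K_train_train: (..., n, n), K_train_test: (..., n, m),
    # K_test_test: (..., m, m), noise: (..., n) per-point variances.
    K_noisy = K_train_train + torch.diag_embed(noise)
    K_inv = torch.linalg.inv(K_noisy)  # explicit inverse, for debugging only
    return K_test_test - K_train_test.transpose(-2, -1) @ K_inv @ K_train_test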
Finally, I just saved a K_train_train to disk. I then loaded it in 1.9.0 and 1.9.1 and called torch.linalg.inv on the loaded tensor, both as a batch of kernel matrices and for individual kernel matrices within. With 1.9.1 in batch mode, approx 60% of entries of the inverse kernel are exactly 0. For a single (non-batch) kernel in 1.9.1, there are no such entries. Similarly, in 1.9.0, there are no zero entries whether in batch mode or not. So perhaps there is an issue with batch matrix inversion in 1.9.1?
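A sketch of that zero-entry check (the file name is illustrative; shapes match the reproducer above):

import torch

K = torch.load("K_train_train.pt").to("cuda")  # shape (11, 100, 100)

inv_batched = torch.linalg.inv(K)  # batched inverse of all 11 matrices at once
inv_single = torch.stack([torch.linalg.inv(K[i]) for i in range(K.shape[0])])

for name, inv in (("batched", inv_batched), ("single", inv_single)):
    frac_zero = (inv == 0).float().mean().item()
    print(f"{name}: fraction of exactly-zero entries = {frac_zero:.1%}")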
Interesting. There should be no explicit inverses anywhere in the code though (there better not be!). Does the same happen when you use cholesky + cholesky_solve instead of the inverse?
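I.e. something along these lines (an untested sketch; the file name is illustrative):

import torch

K = torch.load("K_train_train.pt").to("cuda")  # batch of kernel matrices
eye = torch.eye(K.shape[-1], dtype=K.dtype, device=K.device)
rhs = eye.expand(K.shape).contiguous()

L = torch.linalg.cholesky(K)
K_inv_chol = torch.cholesky_solve(rhs, L)  # K^{-1} via the Cholesky factor

print((K_inv_chol == 0).float().mean())  # fraction of exactly-zero entries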
Can you fix random seeds to make inputs deterministic, and compare f_covar values produced by 1.9.0 and 1.9.1 (without the training loop, as it narrows down things that are going wrong)?
The problem is in torch.cholesky_solve.
Modifying the GPyTorch source in models/exact_prediction_strategies.py as follows "fixes" the problem:
- covar_correction_rhs = train_train_covar.inv_matmul(train_test_covar)
+ covar_correction_rhs = train_train_covar.to('cpu').inv_matmul(train_test_covar.to('cpu')).to('cuda')
Here's a minimal reproducer of the problem (verified with pip installed PyTorch 1.9.1+cu102 in Colab):
import torch
print(torch.__version__) # 1.9.1+cu102
torch.manual_seed(0)
A = torch.randn(11, 100, 100, device='cuda')
A = A @ A.transpose(-2, -1)
b = torch.randn(11, 100, 200, device='cuda')
print(torch.allclose(torch.cholesky_solve(b, A), torch.cholesky_solve(b.cpu(), A.cpu()).cuda())) # False
print(torch.allclose(torch.cholesky_solve(b[0], A[0]), torch.cholesky_solve(b[0].cpu(), A[0].cpu()).cuda())) # True
I built the master branch of PyTorch locally with MAGMA and CUDA 11.3.1, and both allclose calls evaluate to True.
Thanks for debugging @IvanYashchuk, would you file a more detailed issue about what's causing the problem in the PyTorch GitHub so we can review it there?