Unable to achieve convergence on both GPU and CPU
Question Description
I am attempting to fit an exact GP regression on a dataset of ~10,000 points. `train_x` is 3×7740×2 (the base shape 7740×2 repeated along a batch dimension) and `train_y` is 3×7740, where 3 is the batch size. Specifically, each input is a 2-dimensional plane position (X, Y) with decimal values roughly between -18 and -14. Each output is a normalized RGB color with 3 channels, each in [0, 1]. The three regression tasks from XY to R, G, and B are independent of each other.
When following the Batch GP Regression tutorial:
Training on the CPU:
The code does not throw any errors, but it fails to converge and slows down as it runs. However, when I multiply the input `train_x` by 100, the batch GP converges quickly and performs well.
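(For reference: multiplying by 100 is effectively a manual rescaling of the inputs relative to the kernel lengthscale. A minimal sketch of the more conventional alternative, standardizing each input dimension to zero mean and unit variance; this is my illustration, not code from the original post:)

```python
import torch

# Standardize the inputs per dimension (works for both the base 7740 x 2
# tensor and the batched 3 x 7740 x 2 tensor, since dim=-2 is the data axis).
# This puts the data on a scale where default lengthscale initializations
# are reasonable, which is what the *100 trick imitates.
x_mean = train_x.mean(dim=-2, keepdim=True)
x_std = train_x.std(dim=-2, keepdim=True)
train_x = (train_x - x_mean) / x_std
```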
Training on the GPU:
The following errors may occur, including NaN loss and NumericalWarning: CG terminated. I have tried multiplying the input `train_x` by 100 and normalizing it with min-max scaling, but neither worked. When I set the data and model to double precision, the NaN loss disappeared, but training became very slow (as expected for double precision) and still failed to converge to a good solution.
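(A minimal sketch of the double-precision setup described above, assuming the model and likelihood defined in the code further down; this is my illustration, not code from the original post:)

```python
# Cast data, model, and likelihood to float64. This trades extra compute
# cost (often large on consumer GPUs) for much better conditioning in the
# iterative solver.
train_x = train_x.double()
train_y = train_y.double()
model = model.double()
likelihood = likelihood.double()
```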
Is there an issue with my input data? This looks like numerical instability in the linear algebra. I suspect something goes wrong when computing the log likelihood here:
preconditioner, precond_lt, logdet_p = self._preconditioner()
if precond_lt is None:
    from ..operators.identity_linear_operator import IdentityLinearOperator

    precond_lt = IdentityLinearOperator(
        diag_shape=self.size(-1),
        batch_shape=self.batch_shape,
        dtype=self.dtype,
        device=self.device,
    )
    logdet_p = 0.0

precond_args = precond_lt.representation()
probe_vectors, probe_vector_norms = self._probe_vectors_and_norms()

func = InvQuadLogdet.apply
inv_quad_term, pinvk_logdet = func(
    self.representation_tree(),
    precond_lt.representation_tree(),
    preconditioner,
    len(precond_args),
    (inv_quad_rhs is not None),
    probe_vectors,
    probe_vector_norms,
    *(list(args) + list(precond_args)),
)
logdet_term = pinvk_logdet
logdet_term = logdet_term + logdet_p
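(Not from the original report, but since the warnings below point at the iterative solver: a sketch of the knobs the warning message itself suggests, assuming a recent GPyTorch. One can raise the CG iteration budget, or raise `max_cholesky_size` past the dataset size so GPyTorch uses a direct Cholesky solve instead of CG; `model`, `mll`, `train_x`, and `train_y` are the objects from the code below.)

```python
import gpytorch

# Either give CG more iterations to converge, or bypass CG entirely by
# letting matrices up to 8000 x 8000 be solved with a dense Cholesky
# factorization (the 7740-point problem here fits under that cap).
with gpytorch.settings.max_cg_iterations(2000), gpytorch.settings.max_cholesky_size(8000):
    output = model(train_x)
    loss = -mll(output, train_y).sum()
```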
The data can be downloaded from the attached zip file: data.zip
Thanks in advance!
Here are the details of the data, the code, and the error on the GPU.
Data Example
train_x train_y
tensor([-15.0223, -14.0026]) tensor([0.5451, 0.2588, 0.3765])
tensor([-16.1318, -14.1548]) tensor([0.5882, 0.3686, 0.4667])
tensor([-16.7716, -14.5253]) tensor([0.6078, 0.3882, 0.4863])
tensor([-15.9107, -14.8165]) tensor([0.5647, 0.3294, 0.4314])
tensor([-14.9211, -15.1249]) tensor([0.6784, 0.4431, 0.5412])
tensor([-14.0937, -15.4103]) tensor([0.6392, 0.3333, 0.3882])
tensor([-17.4703, -15.1177]) tensor([0.6549, 0.4275, 0.5216])
tensor([-15.7863, -17.1730]) tensor([0.8549, 0.5882, 0.6863])
tensor([-15.3097, -17.5722]) tensor([0.8353, 0.5686, 0.6510])
tensor([-14.6760, -17.8014]) tensor([0.8392, 0.5333, 0.6196])
Code
....
# train data: torch.Size([3, 7740, 2]), torch.Size([3, 7740])
# batch shape: 3
batch_shape = train_y.shape[-1]
train_x = train_x.unsqueeze(0).repeat(batch_shape, 1, 1).to(device)
train_y = train_y.transpose(0, 1).to(device)

class BatchGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_inputs, train_targets, likelihood, batch_shape, use_ard=False):
        super(BatchGPModel, self).__init__(train_inputs, train_targets, likelihood)
        ard_num_dims = train_inputs.shape[-1] if use_ard else None
        self.shape = torch.Size([batch_shape])
        self.mean_module = gpytorch.means.ConstantMean(
            batch_shape=self.shape,
            constant_constraint=gpytorch.constraints.Interval(0.0, 1.0),
        )
        self.base_kernel = gpytorch.kernels.RBFKernel(batch_shape=self.shape, ard_num_dims=ard_num_dims)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel, batch_shape=self.shape)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([batch_shape])).to(device)
model = BatchGPModel(train_x, train_y, likelihood, batch_shape=batch_shape, use_ard=True).to(device)

training_iter = 100
model.train()
likelihood.train()

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Loss for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iter):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y).sum()
    loss.backward()
    optimizer.step()
    print('Iter %d/%d - Loss: %.3f mean0: %.3f mean1: %.3f mean2: %.3f noise0: %.3f noise1: %.3f noise2: %.3f' % (
        i + 1, training_iter, loss.item(),
        model.mean_module.constant[0].item(),
        model.mean_module.constant[1].item(),
        model.mean_module.constant[2].item(),
        model.likelihood.noise[0].item(),
        model.likelihood.noise[1].item(),
        model.likelihood.noise[2].item(),
    ))
Error Message
Iter 1/50 - Loss: 2.251 mean0: 0.574 mean1: 0.426 mean2: 0.574 noise0: 0.554 noise1: 0.554 noise2: 0.554
Iter 2/50 - Loss: 1.910 mean0: 0.643 mean1: 0.415 mean2: 0.583 noise0: 0.437 noise1: 0.437 noise2: 0.437
Iter 3/50 - Loss: 1.553 mean0: 0.700 mean1: 0.436 mean2: 0.559 noise0: 0.341 noise1: 0.341 noise2: 0.341
Iter 4/50 - Loss: 1.179 mean0: 0.744 mean1: 0.469 mean2: 0.529 noise0: 0.263 noise1: 0.263 noise2: 0.263
Iter 5/50 - Loss: 0.792 mean0: 0.775 mean1: 0.502 mean2: 0.515 noise0: 0.201 noise1: 0.201 noise2: 0.201
Iter 6/50 - Loss: 0.395 mean0: 0.793 mean1: 0.517 mean2: 0.524 noise0: 0.152 noise1: 0.152 noise2: 0.152
Iter 7/50 - Loss: nan mean0: 0.799 mean1: 0.506 mean2: 0.544 noise0: 0.114 noise1: 0.114 noise2: 0.114
Iter 8/50 - Loss: nan mean0: 0.795 mean1: 0.483 mean2: 0.559 noise0: 0.086 noise1: 0.086 noise2: 0.086
Iter 9/50 - Loss: nan mean0: 0.780 mean1: 0.462 mean2: 0.555 noise0: 0.064 noise1: 0.064 noise2: 0.064
Iter 10/50 - Loss: nan mean0: 0.756 mean1: 0.461 mean2: 0.538 noise0: 0.047 noise1: 0.048 noise2: 0.047
Iter 11/50 - Loss: nan mean0: 0.723 mean1: 0.481 mean2: 0.522 noise0: 0.035 noise1: 0.035 noise2: 0.035
Iter 12/50 - Loss: nan mean0: 0.681 mean1: 0.511 mean2: 0.525 noise0: 0.026 noise1: 0.026 noise2: 0.026
Iter 13/50 - Loss: nan mean0: 0.633 mean1: 0.532 mean2: 0.546 noise0: 0.019 noise1: 0.019 noise2: 0.019
Iter 14/50 - Loss: nan mean0: 0.586 mean1: 0.528 mean2: 0.570 noise0: 0.014 noise1: 0.014 noise2: 0.014
Iter 15/50 - Loss: nan mean0: 0.551 mean1: 0.510 mean2: 0.578 noise0: 0.011 noise1: 0.011 noise2: 0.011
Iter 16/50 - Loss: nan mean0: 0.534 mean1: 0.504 mean2: 0.565 noise0: 0.008 noise1: 0.008 noise2: 0.008
Iter 17/50 - Loss: nan mean0: 0.537 mean1: 0.498 mean2: 0.571 noise0: 0.006 noise1: 0.006 noise2: 0.006
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:337: NumericalWarning: CG terminated in 1000 iterations with average residual norm 783129706496.0 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
warnings.warn(
Iter 18/50 - Loss: nan mean0: 0.553 mean1: 0.536 mean2: 0.567 noise0: 0.005 noise1: 0.007 noise2: 0.005
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:337: NumericalWarning: CG terminated in 1000 iterations with average residual norm nan which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
warnings.warn(
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/operators/added_diag_linear_operator.py:128: NumericalWarning: NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.
warnings.warn(
Iter 19/50 - Loss: nan mean0: 0.590 mean1: 0.500 mean2: nan noise0: 0.005 noise1: 0.009 noise2: nan
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[8], line 21
17 optimizer.zero_grad()
19 output = model(train_x)
---> 21 loss = -mll(output, train_y).sum()
23 loss.backward()
25 optimizer.step()
File ~/anaconda3/envs/gpytorch/lib/python3.8/site-packages/gpytorch/module.py:31, in Module.__call__(self, *inputs, **kwargs)
30 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 31 outputs = self.forward(*inputs, **kwargs)
32 if isinstance(outputs, list):
33 return [_validate_module_outputs(output) for output in outputs]
File ~/anaconda3/envs/gpytorch/lib/python3.8/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
62 # Get the log prob of the marginal distribution
63 output = self.likelihood(function_dist, *params) # input prior: p(f|X) and output: p(y|X)
---> 64 res = output.log_prob(target) # log p(y|X)
65 res = self._add_other_terms(res, params)
67 # Scale by the amount of data we have
File ~/anaconda3/envs/gpytorch/lib/python3.8/site-packages/gpytorch/distributions/multivariate_normal.py:195, in MultivariateNormal.log_prob(self, value)
191 # Get log determininant and first part of quadratic form
192 # inv_quad = (K+\sigma^2 I)^{-1}
...
201 # Sometime we're lucky and the preconditioner solves the system right away
202 # Check for convergence
203 residual_norm = residual.norm(2, dim=-2, keepdim=True)
RuntimeError: NaNs encountered when trying to perform matrix-vector multiplication
Have you tried using a smaller learning rate? I don't think that this is a bug in GPyTorch, since we are using completely identical code for CPU and GPU.
Thank you for your reply.
I have tried smaller learning rates, starting from 0.3 and gradually decreasing to 0.01, but it still results in NaN values. As the learning rate decreases, the model fails to learn anything.
I have simplified my question. Now we are using completely identical code for the CPU and the GPU; the only addition is:
if torch.cuda.is_available():
    train_x = train_x.cuda()
    train_y = train_y.cuda()
    model = model.cuda()
    likelihood = likelihood.cuda()
The code runs fine on the CPU, but on the GPU it throws NumericalWarning: CG terminated. Doesn't this mean that GPyTorch has a bug?
This is a simple sample file. I hope you have time to take a look at it. test.zip
Many thanks!
I'm not going to open up your zip sample file. If you can post a small reproducible example in the chat here, then I will take a look.
Here is a simple example (using a GP for super-resolution). I use the pixel coordinates XY of the image as `train_x` and the RGB values as `train_y`. Everything works fine when I remove the `.cuda()` code.
import numpy as np
import torch
import gpytorch
from torchvision import transforms

# ori_image: the 60 × 60 source image (loaded elsewhere, e.g. as a PIL image)
image_tensor = transforms.ToTensor()(ori_image)
image_tensor = image_tensor.unsqueeze(0)
b, _, h, w = image_tensor.shape

x = np.arange(w) * 2
y = np.arange(h) * 2
X, Y = np.meshgrid(x, y)
# Cast the integer pixel coordinates to float32 to match train_y's dtype
sample_x = torch.from_numpy(np.stack([X, Y], axis=-1).reshape(-1, 2)).float()
sample_img = image_tensor.squeeze(0).reshape(3, -1).transpose(0, 1)

batch_shape = sample_img.shape[-1]
train_x = sample_x.unsqueeze(0).repeat((batch_shape, 1, 1))
train_y = sample_img.transpose(0, 1)
# [3, 3600, 2], [3, 3600]
print('train_x shape:', train_x.shape, 'train_y shape:', train_y.shape)

class BatchGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_inputs, train_targets, likelihood, batch_shape, use_ard=False):
        super(BatchGPModel, self).__init__(train_inputs, train_targets, likelihood)
        ard_num_dims = train_inputs.shape[-1] if use_ard else None
        self.shape = torch.Size([batch_shape])
        self.mean_module = gpytorch.means.ConstantMean(
            batch_shape=self.shape,
            constant_constraint=gpytorch.constraints.Interval(0.0, 1.0),
        )
        self.base_kernel = gpytorch.kernels.RBFKernel(batch_shape=self.shape, ard_num_dims=ard_num_dims)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel, batch_shape=self.shape)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# Initialize the likelihood and model; the batch shape depends on the
# dimension of y (e.g. an RGB image has 3 channels)
likelihood = gpytorch.likelihoods.GaussianLikelihood(batch_shape=torch.Size([batch_shape]))
model = BatchGPModel(train_x, train_y, likelihood, batch_shape=batch_shape, use_ard=True)

if torch.cuda.is_available():
    train_x = train_x.cuda()
    train_y = train_y.cuda()
    model = model.cuda()
    likelihood = likelihood.cuda()

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(50):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y).sum()
    loss.backward()
    optimizer.step()
    print('Iter %d/%d - Loss: %.3f mean0: %.3f mean1: %.3f mean2: %.3f noise0: %.3f noise1: %.3f noise2: %.3f' % (
        i + 1, 50, loss.item(),
        model.mean_module.constant[0].item(),
        model.mean_module.constant[1].item(),
        model.mean_module.constant[2].item(),
        model.likelihood.noise[0].item(),
        model.likelihood.noise[1].item(),
        model.likelihood.noise[2].item(),
    ))
Otherwise (with the `.cuda()` lines in place), I get the following warnings:
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 81.5545425415039 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
warnings.warn(
Iter 1/50 - Loss: 2.826 mean0: 0.525 mean1: 0.475 mean2: 0.475 noise0: 0.644 noise1: 0.744 noise2: 0.744
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 14.917732238769531 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
warnings.warn(
Iter 2/50 - Loss: 2.931 mean0: 0.550 mean1: 0.469 mean2: 0.466 noise0: 0.607 noise1: 0.779 noise2: 0.780
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 89.217041015625 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
warnings.warn(
Iter 3/50 - Loss: 2.931 mean0: 0.574 mean1: 0.475 mean2: 0.465 noise0: 0.581 noise1: 0.772 noise2: 0.809
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 456.7450866699219 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
warnings.warn(
Iter 4/50 - Loss: 2.898 mean0: 0.598 mean1: 0.486 mean2: 0.469 noise0: 0.553 noise1: 0.768 noise2: 0.832
/home/dell/anaconda3/envs/gpytorch/lib/python3.8/site-packages/linear_operator/utils/linear_cg.py:338: NumericalWarning: CG terminated in 1000 iterations with average residual norm 36.63666915893555 which is larger than the tolerance of 1 specified by linear_operator.settings.cg_tolerance. If performance is affected, consider raising the maximum number of CG iterations by running code in a linear_operator.settings.max_cg_iterations(value) context.
warnings.warn(
Iter 5/50 - Loss: 2.851 mean0: 0.621 mean1: 0.497 mean2: 0.476 noise0: 0.527 noise1: 0.765 noise2: 0.853