[Bug] Type error in multitask, large input for GridInterpolationKernel
🐛 Bug
I was trying to extend the MultitaskKernel example by including GridInterpolationKernel, with input of more than one dimension. However, when I use all three features together (multitask, grid interpolation kernel, multi-dimensional input), I get an error from inside the library.
To reproduce
** Code snippet to reproduce **
import math
import torch
import gpytorch

train_x = torch.cartesian_prod(torch.linspace(-1, 1, 50), torch.linspace(-1, 1, 50))
train_y = torch.stack([
    torch.sin(train_x[:, 0] * (2 * math.pi)) * torch.sin(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:, 0].size()) * 0.2,
    torch.sin(train_x[:, 0] * (2 * math.pi)) * torch.cos(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:, 0].size()) * 0.2,
    torch.cos(train_x[:, 0] * (2 * math.pi)) * torch.sin(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:, 0].size()) * 0.2,
    torch.cos(train_x[:, 0] * (2 * math.pi)) * torch.cos(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:, 0].size()) * 0.2,
], -1)
class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=4
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.GridInterpolationKernel(
                gpytorch.kernels.RBFKernel(),
                grid_size=10, num_dims=2
            ), num_tasks=4, rank=0
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=4)
model = MultitaskGPModel(train_x, train_y, likelihood)

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

training_iterations = 2
for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
    optimizer.step()
** Stack trace/error message **
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_13748/1593994832.py in <module>
46 optimizer.zero_grad()
47 output = model(train_x)
---> 48 loss = -mll(output, train_y)
49 loss.backward()
50 print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\module.py in __call__(self, *inputs, **kwargs)
28
29 def __call__(self, *inputs, **kwargs):
---> 30 outputs = self.forward(*inputs, **kwargs)
31 if isinstance(outputs, list):
32 return [_validate_module_outputs(output) for output in outputs]
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\mlls\exact_marginal_log_likelihood.py in forward(self, function_dist, target, *params)
60 # Get the log prob of the marginal distribution
61 output = self.likelihood(function_dist, *params)
---> 62 res = output.log_prob(target)
63 res = self._add_other_terms(res, params)
64
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\distributions\multitask_multivariate_normal.py in log_prob(self, value)
209 new_shape = value.shape[:-2] + value.shape[:-3:-1]
210 value = value.view(new_shape).transpose(-1, -2).contiguous()
--> 211 return super().log_prob(value.view(*value.shape[:-2], -1))
212
213 @property
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\distributions\multivariate_normal.py in log_prob(self, value)
167 # Get log determininant and first part of quadratic form
168 covar = covar.evaluate_kernel()
--> 169 inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
170
171 res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_added_diag_lazy_tensor.py in inv_quad_logdet(self, inv_quad_rhs, logdet, reduce_inv_quad)
60 def inv_quad_logdet(self, inv_quad_rhs=None, logdet=False, reduce_inv_quad=True):
61 if inv_quad_rhs is not None:
---> 62 inv_quad_term, _ = super().inv_quad_logdet(
63 inv_quad_rhs=inv_quad_rhs, logdet=False, reduce_inv_quad=reduce_inv_quad
64 )
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in inv_quad_logdet(self, inv_quad_rhs, logdet, reduce_inv_quad)
1332 func = InvQuadLogDet.apply
1333
-> 1334 inv_quad_term, logdet_term = func(
1335 self.representation_tree(),
1336 self.dtype,
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\functions\_inv_quad_log_det.py in forward(ctx, representation_tree, dtype, device, matrix_shape, batch_shape, inv_quad, logdet, probe_vectors, probe_vector_norms, *args)
158
159 else:
--> 160 solves = lazy_tsr._solve(rhs, preconditioner, num_tridiag=0)
161
162 # Final values to return
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_added_diag_lazy_tensor.py in _solve(self, rhs, preconditioner, num_tridiag)
187 # https://papers.nips.cc/paper/2013/file/59c33016884a62116be975a9bb8257e3-Paper.pdf
188
--> 189 dlt_inv_root, evals_p_i, evecs = _symmetrize_kpadlt_constructor(lt, dlt)
190
191 res1 = evecs._transpose_nonbatch().matmul(dlt_inv_root.matmul(rhs))
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_added_diag_lazy_tensor.py in _symmetrize_kpadlt_constructor(lt, dlt)
37 *[d.matmul(k).matmul(d) for k, d in zip(lt.lazy_tensors, dlt_inv_root.lazy_tensors)]
38 )
---> 39 evals, evecs = symm_prod.diagonalization()
40 evals_plus_i = DiagLazyTensor(evals + 1.0)
41
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_lazy_tensor.py in diagonalization(self, method)
142 if method is None:
143 method = "symeig"
--> 144 return super().diagonalization(method=method)
145
146 @cached
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
57 kwargs_pkl = pickle.dumps(kwargs)
58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59 return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
61
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in diagonalization(self, method)
1634
1635 elif method == "symeig":
-> 1636 evals, evecs = self.symeig(eigenvectors=True)
1637 else:
1638 raise RuntimeError(f"Unknown diagonalization method '{method}'")
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
57 kwargs_pkl = pickle.dumps(kwargs)
58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59 return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
61
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in symeig(self, eigenvectors)
1911 except CachingError:
1912 pass
-> 1913 return self._symeig(eigenvectors=eigenvectors)
1914
1915 def to(self, *args, **kwargs):
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\kronecker_product_lazy_tensor.py in _symeig(self, eigenvectors, return_evals_as_lazy)
292 evals, evecs = [], []
293 for lt in self.lazy_tensors:
--> 294 evals_, evecs_ = lt.symeig(eigenvectors=eigenvectors)
295 evals.append(evals_)
296 evecs.append(evecs_)
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
57 kwargs_pkl = pickle.dumps(kwargs)
58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59 return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
61
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in symeig(self, eigenvectors)
1911 except CachingError:
1912 pass
-> 1913 return self._symeig(eigenvectors=eigenvectors)
1914
1915 def to(self, *args, **kwargs):
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in _symeig(self, eigenvectors)
2243 # potentially perform decomposition in double precision for numerical stability
2244 dtype = self.dtype
-> 2245 evals, evecs = torch.linalg.eigh(self.evaluate().to(dtype=settings._linalg_dtype_symeig.value()))
2246 # chop any negative eigenvalues.
2247 # TODO: warn if evals are significantly negative
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
57 kwargs_pkl = pickle.dumps(kwargs)
58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59 return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
61
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\matmul_lazy_tensor.py in evaluate(self)
114 @cached
115 def evaluate(self):
--> 116 return torch.matmul(self.left_lazy_tensor.evaluate(), self.right_lazy_tensor.evaluate())
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
57 kwargs_pkl = pickle.dumps(kwargs)
58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59 return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
61
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\matmul_lazy_tensor.py in evaluate(self)
114 @cached
115 def evaluate(self):
--> 116 return torch.matmul(self.left_lazy_tensor.evaluate(), self.right_lazy_tensor.evaluate())
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\memoize.py in g(self, *args, **kwargs)
57 kwargs_pkl = pickle.dumps(kwargs)
58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59 return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)
61
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\lazy_tensor.py in evaluate(self)
1146 eye = torch.eye(num_cols, dtype=self.dtype, device=self.device)
1147 eye = eye.expand(*self.batch_shape, num_cols, num_cols)
-> 1148 res = self.matmul(eye)
1149 return res
1150
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\lazy\interpolated_lazy_tensor.py in matmul(self, tensor)
407 # right_interp^T * tensor
408 base_size = self.base_lazy_tensor.size(-1)
--> 409 right_interp_res = left_t_interp(self.right_interp_indices, self.right_interp_values, tensor, base_size)
410
411 # base_lazy_tensor * right_interp^T * tensor
~\AppData\Local\Programs\Python\Python39\lib\site-packages\gpytorch\utils\interpolation.py in left_t_interp(interp_indices, interp_values, rhs, output_dim)
226 else:
227 cls = getattr(torch.sparse, type_name)
--> 228 summing_matrix = cls(summing_matrix_indices, summing_matrix_values, size)
229
230 # Sum up the values appropriately by performing sparse matrix multiplication
RuntimeError: expected scalar type Long but found Double
Expected Behavior
A very similar implementation using batch-independent multitask works well, so I'll include it here for comparison, along with some visualizations of the data.
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

train_x = torch.cartesian_prod(torch.linspace(-1, 1, 50), torch.linspace(-1, 1, 50))
train_y = torch.stack([
    torch.sin(train_x[:, 0] * (2 * math.pi)) * torch.sin(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:, 0].size()) * 0.2,
    torch.sin(train_x[:, 0] * (2 * math.pi)) * torch.cos(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:, 0].size()) * 0.2,
    torch.cos(train_x[:, 0] * (2 * math.pi)) * torch.sin(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:, 0].size()) * 0.2,
    torch.cos(train_x[:, 0] * (2 * math.pi)) * torch.cos(train_x[:, 1] * (2 * math.pi)) + torch.randn(train_x[:, 0].size()) * 0.2,
], -1)
class BatchIndependentMultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([4]))
        self.covar_module = gpytorch.kernels.GridInterpolationKernel(
            gpytorch.kernels.SpectralMixtureKernel(
                num_mixtures=10,
                ard_num_dims=2,
                batch_shape=torch.Size([4])
            ),
            grid_size=10,
            num_dims=2,
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal.from_batch_mvn(
            gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        )
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=4)
model = BatchIndependentMultitaskGPModel(train_x, train_y, likelihood)

# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
training_iterations = 2 if smoke_test else 50

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
    optimizer.step()
Visualization of each task (plots omitted here).
System information
Please complete the following information:
- Python Version: 3.9.6 (run in a Jupyter notebook)
- GPyTorch Version: 1.6.0
- PyTorch Version: 1.10.0+cpu
- OS: Windows 10
While testing this bug further, I found that it happens whenever there are too many data points, even without multi-dimensional input.
- If we change train_x to train_x = torch.cartesian_prod(torch.linspace(-1, 1, 5), torch.linspace(-1, 1, 5)), the code works fine.
- Conversely, if we use one-dimensional data but a large dataset, we get the same bug:
import math
import torch
import gpytorch

train_x = torch.linspace(-1, 1, 500)
train_y = torch.stack([
    torch.sin(train_x * (2 * math.pi)) * torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.sin(train_x * (2 * math.pi)) * torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (2 * math.pi)) * torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
    torch.cos(train_x * (2 * math.pi)) * torch.cos(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2,
], -1)
class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=4
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.GridInterpolationKernel(
                gpytorch.kernels.RBFKernel(),
                grid_size=10, num_dims=1
            ), num_tasks=4, rank=0
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)
likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=4)
model = MultitaskGPModel(train_x, train_y, likelihood)

# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)  # Includes GaussianLikelihood parameters

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

training_iterations = 2
for i in range(training_iterations):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iterations, loss.item()))
    optimizer.step()
So I'll change the issue title from "multi-dimensional" to "large".
I just encountered the same problem with 2-dim input and 2-dim output for 5000 training points. I got the grid size from gpytorch.utils.grid.choose_grid_size(inputs_trn, 1.0) and created the covar_module like this:

base_kernel = gpytorch.kernels.MaternKernel()
kernel = gpytorch.kernels.GridInterpolationKernel(base_kernel, grid_size, num_dims=num_tasks)
covar_module = gpytorch.kernels.MultitaskKernel(kernel, num_tasks, kernel_rank)
Yeah - this looks like a bug in GPyTorch/LinearOperator. I do not have the bandwidth to investigate this at the moment, so I'm looking for someone who can put up a PR to fix this.
I took a look into this error and found that the root cause of the type error is in KroneckerProductAddedDiagLinearOperator._solve in LinearOperator. On lines 177-181, we convert the lt and dlt matrices to torch.float64:
# again we perform the solve in double precision for numerical stability issues
# TODO: Use fp64 registry once #1213 is addressed
rhs = rhs.to(symeig_dtype)
lt = self.linear_op.to(symeig_dtype)
dlt = self.diag_tensor.to(symeig_dtype)
But this operation is problematic for InterpolatedLinearOperator, because it also converts the operator's index matrices (left_interp_indices, right_interp_indices) to float64, which then causes the type error when a sparse matrix is built from those indices.
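For illustration, here is a minimal, standalone demonstration of the underlying dtype constraint (independent of GPyTorch): the legacy sparse constructor used in left_t_interp requires Long index tensors, so indices that have been cast to float64 trigger exactly this kind of error (the exact message may vary across PyTorch versions).

import torch

values = torch.tensor([1.0, 2.0], dtype=torch.float64)
long_indices = torch.tensor([[0, 1], [0, 1]])        # what the operator normally stores
float_indices = long_indices.to(torch.float64)       # what a blanket .to(torch.float64) produces

size = torch.Size([2, 2])
torch.sparse.DoubleTensor(long_indices, values, size)     # works
try:
    torch.sparse.DoubleTensor(float_indices, values, size)
except RuntimeError as e:
    print(e)  # complains that Long indices were expected but Double was found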
As for the dependence on input size: the type error does not occur when Cholesky is used for the inverse quadratic and logdet terms in _linear_operator.inv_quad_logdet(), which happens when settings.fast_computations.log_prob.off() is true or when input_size * output_dimension <= settings.max_cholesky_size.value() (800 by default).
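Based on that observation, a possible stopgap (not a fix) is to force the Cholesky path so the failing Kronecker solve is never reached, e.g. by raising max_cholesky_size above num_points * num_tasks or by disabling the fast log-prob computation. This trades memory and runtime for avoiding the bug.

# Possible workaround for the repro above (2500 points x 4 tasks = 10000);
# it only sidesteps the bug by forcing Cholesky.
with gpytorch.settings.max_cholesky_size(11000):
    output = model(train_x)
    loss = -mll(output, train_y)

# or, equivalently, turn off the fast log-prob path entirely:
with gpytorch.settings.fast_computations(log_prob=False):
    output = model(train_x)
    loss = -mll(output, train_y)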
I'd like to collect suggestions on the best way to fix this from people with more experience with this package. One straightforward option is to avoid converting the index matrices to a floating-point dtype when calling to() in _linear_operator, but there might be an easier way.
Yeah, it seems like not converting the index matrices in InterpolatedLinearOperator when calling to() is something that we'd want anyway. We can basically just overwrite the to() method on InterpolatedLinearOperator to only move the base linear operator. Would you be willing to help out and put up a PR for this on the linear_operator repo?
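A rough sketch of what that override could look like (illustration only: the real to() in linear_operator also handles devices and non-tensor arguments, and the attribute/constructor names below are taken from the stack trace, so they may need adjusting to the actual InterpolatedLinearOperator implementation):

# Sketch only, intended as a method on InterpolatedLinearOperator: cast the base
# operator and the floating-point interpolation weights, but leave the integer
# interpolation indices as torch.long so sparse construction in left_t_interp
# keeps working.
def to(self, dtype):
    if not isinstance(dtype, torch.dtype) or not dtype.is_floating_point:
        return super().to(dtype)
    return self.__class__(
        self.base_linear_op.to(dtype),
        self.left_interp_indices,            # keep Long
        self.left_interp_values.to(dtype),
        self.right_interp_indices,           # keep Long
        self.right_interp_values.to(dtype),
    )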
Sure, I will work on it and add a unit test for that.
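For reference, a hypothetical shape of such a unit test (the import paths, constructor signature, and attribute names below are assumptions based on the stack trace and the linear_operator naming, and would need to be checked against the repo):

import torch
from linear_operator import to_linear_operator
from linear_operator.operators import InterpolatedLinearOperator


def test_to_dtype_keeps_interp_indices_long():
    base = to_linear_operator(torch.eye(4))
    indices = torch.zeros(3, 2, dtype=torch.long)
    values = torch.rand(3, 2)
    op = InterpolatedLinearOperator(base, indices, values, indices, values)

    op_double = op.to(torch.float64)

    # the floating-point parts should be cast ...
    assert op_double.dtype == torch.float64
    # ... while the interpolation indices must stay integer-typed
    assert op_double.left_interp_indices.dtype == torch.long
    assert op_double.right_interp_indices.dtype == torch.long
    # and evaluating the operator should still work after the cast
    op_double.to_dense()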