
[Bug] MultiDeviceKernel breaks for non-exact GP regression

Open markoris opened this issue 9 months ago • 1 comment

🐛 Bug

When using MultiDeviceKernel for multi-GPU training outside of the ExactGP tutorial scenario, model evaluation throws a RuntimeError because the kernel's output tensors end up on more than one GPU.
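For contrast, the multi-GPU ExactGP tutorial setting, where MultiDeviceKernel is the documented use case, looks roughly like this (a minimal sketch, not the tutorial code verbatim):

import torch
import gpytorch

# Minimal sketch of the supported ExactGP multi-GPU setup; the same
# MultiDeviceKernel arguments are used in the repro below.
class ExactMultiGPUModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_devices):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        base_covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        self.covar_module = gpytorch.kernels.MultiDeviceKernel(
            base_covar_module,
            device_ids=range(n_devices),
            output_device=torch.device('cuda:0'),
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

Swapping ExactGP for ApproximateGP with a variational strategy, as in the repro below, is what triggers the device mismatch.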

To reproduce

**Code snippet to reproduce**

  1 import torch
  2 import gpytorch
  3 
  4 '''
  5 Adapted from https://github.com/ianhill60/gprsampledataset/blob/5982ea5d991ee6fb733627211dd60e447a6750e5/VariationalMultitaskApplication.py
  6 '''
  7 
  8 class MultitaskGPModel(gpytorch.models.ApproximateGP):
  9     def __init__(self, train_x, num_inducing_pts, num_tasks, num_latents):
 10 
 11         inducing_points = (train_x[torch.randperm(min(1000 * 100, len(train_x)))[0:num_inducing_pts], :])
 12 
 13         if torch.cuda.is_available():
 14             inducing_points = inducing_points.to('cuda:0')
 15 
 16         variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
 17             inducing_points.size(0), batch_shape=torch.Size([num_latents])
 18         )
 19 
 20         variational_strategy = gpytorch.variational.LMCVariationalStrategy(
 21             gpytorch.variational.VariationalStrategy(
 22                 self, inducing_points, variational_distribution, learn_inducing_locations=True
 23             ),
 24             num_tasks=num_tasks,
 25             num_latents=num_latents,
 26             latent_dim=-1
 27         )
 28 
 29         super().__init__(variational_strategy)
 30 
 31         self.mean_module = gpytorch.means.ConstantMean(batch_shape=torch.Size([num_latents]))
 32         base_covar_module = gpytorch.kernels.ScaleKernel(
 33             gpytorch.kernels.RBFKernel(batch_shape=torch.Size([num_latents])),
 34             batch_shape=torch.Size([num_latents])
 35         )
 36         self.covar_module = gpytorch.kernels.MultiDeviceKernel(
 37             base_covar_module, device_ids=range(torch.cuda.device_count()),
 38             output_device=torch.device('cuda:0')
 39         )
 40 
 41     def forward(self, x):
 42 
 43         mean_x = self.mean_module(x)
 44         covar_x = self.covar_module(x)
 45         return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
 46 
 47 class Interpolator(object):
 48 
 49     def __init__(self, num_inducing_pts=1000, num_tasks=264, num_latents=6):
 50 
 51         self.X = torch.rand([100, 6])
 52         self.y = torch.rand([100, num_tasks])
 53 
 54         self.num_inducing_pts = num_inducing_pts
 55         self.num_tasks = num_tasks
 56         self.num_latents = num_latents
 57 
 58         self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 59         self.num_epochs = 100
 60 
 61     def setup_GP(self):
 62 
 63         self.model = MultitaskGPModel(self.X, self.num_inducing_pts, self.num_tasks, self.num_latents).double().to(self.device)
 64         self.likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=self.num_tasks).to(self.device)
 65 
 66         self.model.train()
 67         self.likelihood.train()
 68 
 69         self.optimizer = torch.optim.Adam([
 70             {'params': self.model.parameters()},
 71             {'params': self.likelihood.parameters()},
 72         ], lr=.001)
 73 
 74         self.mll = gpytorch.mlls.VariationalELBO(self.likelihood, self.model, num_data=self.y.size(0)).cuda()
 75 
 76     def preprocessing(self):
 77 
 78         train_dataset = torch.utils.data.TensorDataset(self.X, self.y)
 79         self.train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, drop_last=False)
 80 
 81     def train(self):
 82 
 83         val_loss = []
 84         for i in range(self.num_epochs):
 85             batch_losses = []
 86             for (x_batch, y_batch) in self.train_loader:
 87                 x_batch = x_batch.to(self.device)
 88                 y_batch = y_batch.to(self.device)
 89                 self.optimizer.zero_grad()
 90                 output = self.model(x_batch)
 91                 loss = -self.mll(output, y_batch)
 92                 loss.backward()
 93                 self.optimizer.step()
 94 
 95                 batch_losses.append(loss.item())
 96             print('Iter %d/%d - Loss: %.3f   noise: %.3f' % (
 97                 i + 1, self.num_epochs, batch_losses[-1],
 98                 self.likelihood.noise.item()
 99                 )
100             )
101 
102 if __name__=='__main__':
103     intp = Interpolator()
104     intp.preprocessing()
105     intp.setup_GP()
106     intp.train()

**Stack trace/error message**

Traceback (most recent call last):
  File "/users/mristic/gpytorch_for_placement/check_multigpu.py", line 106, in <module>
    intp.train()
  File "/users/mristic/gpytorch_for_placement/check_multigpu.py", line 90, in train
    output = self.model(x_batch)
  File "/users/mristic/.local/lib/python3.10/site-packages/gpytorch/models/approximate_gp.py", line 114, in __call__
    return self.variational_strategy(inputs, prior=prior, **kwargs)
  File "/users/mristic/.local/lib/python3.10/site-packages/gpytorch/variational/lmc_variational_strategy.py", line 197, in __call__
    latent_dist = self.base_variational_strategy(x, prior=prior, **kwargs)
  File "/users/mristic/.local/lib/python3.10/site-packages/gpytorch/variational/variational_strategy.py", line 272, in __call__
    return super().__call__(x, prior=prior, **kwargs)
  File "/users/mristic/.local/lib/python3.10/site-packages/gpytorch/variational/_variational_strategy.py", line 347, in __call__
    return super().__call__(
  File "/users/mristic/.local/lib/python3.10/site-packages/gpytorch/module.py", line 31, in __call__
    outputs = self.forward(*inputs, **kwargs)
  File "/users/mristic/.local/lib/python3.10/site-packages/gpytorch/variational/variational_strategy.py", line 197, in forward
    induc_data_covar = full_covar[..., :num_induc, num_induc:].to_dense()
  File "/users/mristic/.local/lib/python3.10/site-packages/gpytorch/utils/memoize.py", line 59, in g
    return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
  File "/users/mristic/.local/lib/python3.10/site-packages/gpytorch/lazy/lazy_evaluated_kernel_tensor.py", line 410, in to_dense
    return self.evaluate_kernel().to_dense()
  File "/users/mristic/.local/lib/python3.10/site-packages/linear_operator/operators/cat_linear_operator.py", line 384, in to_dense
    return torch.cat([to_dense(L) for L in self.linear_ops], dim=self.cat_dim)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_CUDA_cat)
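
The failing call is the torch.cat in CatLinearOperator.to_dense, which receives dense chunks that still live on the devices they were computed on. A possible local workaround, untested and assuming CatLinearOperator keeps the output_device attribute it is constructed with, would be to move each chunk to one device before concatenating:

import torch
from linear_operator.operators.cat_linear_operator import CatLinearOperator

# Untested monkey-patch sketch: bring every dense chunk to a single
# device before torch.cat, rather than assuming they already agree.
# Assumes self.output_device (set in __init__) is still available here.
def _to_dense_on_one_device(self):
    target = self.output_device if self.output_device is not None else self.linear_ops[0].device
    return torch.cat(
        [linear_op.to_dense().to(target) for linear_op in self.linear_ops],
        dim=self.cat_dim,
    )

CatLinearOperator.to_dense = _to_dense_on_one_device

This only papers over the symptom at the point of the traceback; where the variational strategy expects its covariance chunks to live would still need a proper fix.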

Expected Behavior

Training should complete successfully on multiple GPUs.

System information

Please complete the following information:

  • GPyTorch Version 1.14
  • PyTorch Version 2.5.1
  • SUSE Linux Enterprise Server 15 SP4

markoris · Mar 07 '25 20:03

I don't have the hardware to test this myself, but does the same issue happen for single-output non-exact GP models when using multiple GPUs?
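
A single-output check might look like the sketch below (untested, for the same reason; the class name and sizes are just illustrative):

import torch
import gpytorch

# Untested single-output sketch: the same VariationalStrategy +
# MultiDeviceKernel combination, without the LMC/multitask layers.
class SingleOutputSVGP(gpytorch.models.ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(0)
        )
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.MultiDeviceKernel(
            gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel()),
            device_ids=range(torch.cuda.device_count()),
            output_device=torch.device('cuda:0'),
        )

    def forward(self, x):
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

if __name__ == '__main__':
    X = torch.rand(100, 6).to('cuda:0')
    model = SingleOutputSVGP(X[:20].clone()).to('cuda:0')
    model.train()
    output = model(X)  # if the bug is in VariationalStrategy itself, this should hit the same cat error

If this fails the same way, the bug is in VariationalStrategy's handling of the concatenated kernel rather than anything LMC- or multitask-specific.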

mrlj-hash · Apr 08 '25 02:04