[Bug] Broken example "Exact GP Regression with Multiple GPUs"
🐛 Bug
The multi-device example "Exact GP Regression with Multiple GPUs" currently produces a runtime error on a standard 8x V100 node. The error is: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!
To reproduce
Run the Exact GP Regression with Multiple GPUs notebook; it fails during the call to train.
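For context, the relevant part of the notebook is roughly the following (a condensed sketch, paraphrased from the tutorial; the notebook's train helper then fits this model with an L-BFGS-style closure loop, as visible in the trace below):

```python
import torch
import gpytorch

# Sketch of the notebook's setup (paraphrased): the kernel matrix is split
# across all visible GPUs and gathered back on output_device.
output_device = torch.device("cuda:0")
n_devices = torch.cuda.device_count()

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_devices):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        base_covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        # MultiDeviceKernel distributes kernel computation over device_ids.
        self.covar_module = gpytorch.kernels.MultiDeviceKernel(
            base_covar_module, device_ids=range(n_devices), output_device=output_device
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
```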
**Stack trace/error message**
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[7], line 1
----> 1 model, likelihood = train(train_x, train_y,
      2                           n_devices=n_devices, output_device=output_device,
      3                           preconditioner_size=100,
      4                           n_training_iter=20)

Cell In[6], line 42
     39     loss = -mll(output, train_y)
     40     return loss
---> 42 loss = closure()
     43 loss.backward()
     45 for i in range(n_training_iter):

Cell In[6], line 39
     37 optimizer.zero_grad()
     38 output = model(train_x)
---> 39 loss = -mll(output, train_y)
     40 return loss

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/gpytorch/module.py:31, in Module.__call__(self, *inputs, **kwargs)
     30 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 31     outputs = self.forward(*inputs, **kwargs)
     32     if isinstance(outputs, list):
     33         return [_validate_module_outputs(output) for output in outputs]

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:82, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params, **kwargs)
     79     raise ValueError("NaN observation policy 'fill' is not supported by ExactMarginalLogLikelihood!")
     81 # Get the log prob of the marginal distribution
---> 82 res = output.log_prob(target)
     83 res = self._add_other_terms(res, params)
     85 # Scale by the amount of data we have

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/gpytorch/distributions/multivariate_normal.py:193, in MultivariateNormal.log_prob(self, value)
    191 # Get log determininant and first part of quadratic form
    192 covar = covar.evaluate_kernel()
--> 193 inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
    195 res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
    196 return res

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py:1709, in LinearOperator.inv_quad_logdet(self, inv_quad_rhs, logdet, reduce_inv_quad)
   1707 will_need_cholesky = False
   1708 if will_need_cholesky:
-> 1709     cholesky = CholLinearOperator(TriangularLinearOperator(self.cholesky()))
   1710     return cholesky.inv_quad_logdet(
   1711         inv_quad_rhs=inv_quad_rhs,
   1712         logdet=logdet,
   1713         reduce_inv_quad=reduce_inv_quad,
   1714     )
   1716 # Short circuit to inv_quad function if we're not computing logdet

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py:1311, in LinearOperator.cholesky(self, upper)
   1301 @_implements(torch.linalg.cholesky)
   1302 def cholesky(
   1303     self: Float[LinearOperator, "*batch N N"], upper: bool = False
   1304 ) -> Float[LinearOperator, "*batch N N"]:  # returns TriangularLinearOperator
   1305     """
   1306     Cholesky-factorizes the LinearOperator.
   1307
   1308     :param upper: Upper triangular or lower triangular factor (default: False).
   1309     :return: Cholesky factor (lower or upper triangular)
   1310     """
-> 1311     chol = self._cholesky(upper=False)
   1312     if upper:
   1313         chol = chol._transpose_nonbatch()

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/linear_operator/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
     57 kwargs_pkl = pickle.dumps(kwargs)
     58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59     return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/linear_operator/operators/_linear_operator.py:521, in LinearOperator._cholesky(self, upper)
    518 if any(isinstance(sub_mat, KeOpsLinearOperator) for sub_mat in evaluated_kern_mat._args):
    519     raise RuntimeError("Cannot run Cholesky with KeOps: it will either be really slow or not work.")
--> 521 evaluated_mat = evaluated_kern_mat.to_dense()
    523 # if the tensor is a scalar, we can just take the square root
    524 if evaluated_mat.size(-1) == 1:

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/linear_operator/utils/memoize.py:59, in _cached.<locals>.g(self, *args, **kwargs)
     57 kwargs_pkl = pickle.dumps(kwargs)
     58 if not _is_in_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl):
---> 59     return _add_to_cache(self, cache_name, method(self, *args, **kwargs), *args, kwargs_pkl=kwargs_pkl)
     60 return _get_from_cache(self, cache_name, *args, kwargs_pkl=kwargs_pkl)

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/linear_operator/operators/sum_linear_operator.py:81, in SumLinearOperator.to_dense(self)
     79 @cached
     80 def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]:
---> 81     return (sum(linear_op.to_dense() for linear_op in self.linear_ops)).contiguous()

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/linear_operator/operators/sum_linear_operator.py:81, in <genexpr>(.0)
     79 @cached
     80 def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]:
---> 81     return (sum(linear_op.to_dense() for linear_op in self.linear_ops)).contiguous()

File ~/micromamba/envs/newenv/lib/python3.12/site-packages/linear_operator/operators/cat_linear_operator.py:384, in CatLinearOperator.to_dense(self)
    383 def to_dense(self: Float[LinearOperator, "*batch M N"]) -> Float[Tensor, "*batch M N"]:
--> 384     return torch.cat([to_dense(L) for L in self.linear_ops], dim=self.cat_dim)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument tensors in method wrapper_CUDA_cat)
Expected Behavior
The example should run on multiple GPUs without error.
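One way to rule out a simple input-placement problem is a device check along these lines (a hypothetical snippet, not part of the notebook); if the data and parameters all report output_device, the cuda:0/cuda:1 mismatch must be coming from tensors created inside the multi-device kernel / linear_operator code path:

```python
# Hypothetical sanity check (not part of the notebook): confirm that the
# training data and every model parameter live on the same device before
# calling train().
def report_devices(model, train_x, train_y):
    devices = {train_x.device, train_y.device}
    devices.update(p.device for p in model.parameters())
    print("Devices in use:", devices)
```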
System information
Please complete the following information:
- GPyTorch version: 1.12
- PyTorch version: 2.4.0+cu121
- OS: Ubuntu 22.04.3 LTS
- Hardware: GPU node provisioned via Lambda Labs, 8x Tesla V100 (16 GB), 92 CPU cores, 460.1 GB RAM, 6.5 TB SSD
Additional context
Here is the output of nvidia-smi, in case the GPUs or drivers are relevant:
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 Tesla V100-SXM2-16GB On | 00000000:00:04.0 Off | 0 |
| N/A 39C P0 55W / 300W | 1033MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 1 Tesla V100-SXM2-16GB On | 00000000:00:05.0 Off | 0 |
| N/A 40C P0 59W / 300W | 1061MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 2 Tesla V100-SXM2-16GB On | 00000000:00:06.0 Off | 0 |
| N/A 44C P0 62W / 300W | 965MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 3 Tesla V100-SXM2-16GB On | 00000000:00:07.0 Off | 0 |
| N/A 38C P0 55W / 300W | 969MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 4 Tesla V100-SXM2-16GB On | 00000000:00:08.0 Off | 0 |
| N/A 40C P0 55W / 300W | 1041MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 5 Tesla V100-SXM2-16GB On | 00000000:00:09.0 Off | 0 |
| N/A 43C P0 55W / 300W | 1057MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 6 Tesla V100-SXM2-16GB On | 00000000:00:0A.0 Off | 0 |
| N/A 43C P0 56W / 300W | 1009MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
| 7 Tesla V100-SXM2-16GB On | 00000000:00:0B.0 Off | 0 |
| N/A 41C P0 59W / 300W | 985MiB / 16384MiB | 0% Default |
| | | N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=======================================================================================|
| 0 N/A N/A 6876 C ...u/micromamba/envs/newenv/bin/python 1030MiB |
| 1 N/A N/A 6876 C ...u/micromamba/envs/newenv/bin/python 1058MiB |
| 2 N/A N/A 6876 C ...u/micromamba/envs/newenv/bin/python 962MiB |
| 3 N/A N/A 6876 C ...u/micromamba/envs/newenv/bin/python 966MiB |
| 4 N/A N/A 6876 C ...u/micromamba/envs/newenv/bin/python 1038MiB |
| 5 N/A N/A 6876 C ...u/micromamba/envs/newenv/bin/python 1054MiB |
| 6 N/A N/A 6876 C ...u/micromamba/envs/newenv/bin/python 1006MiB |
| 7 N/A N/A 6876 C ...u/micromamba/envs/newenv/bin/python 982MiB |
+---------------------------------------------------------------------------------------+
Hi Anthony, have you fixed this bug? I hit the same error when running the same example code. Thank you.
Dear all,
I am experiencing a similar bug.
- The Google Drive URL leads to a 404 error.
- The training method raises the following RuntimeError:
RuntimeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 model, likelihood = train(train_x, train_y,
2 n_devices=n_devices, output_device=output_device,
3 preconditioner_size=100,
4 n_training_iter=20)
Cell In[5], line 42, in train(train_x, train_y, n_devices, output_device, preconditioner_size, n_training_iter)
39 loss = -mll(output, train_y)
40 return loss
---> 42 loss = closure()
43 loss.backward()
45 for i in range(n_training_iter):
Cell In[5], line 39, in train.<locals>.closure()
37 optimizer.zero_grad()
38 output = model(train_x)
---> 39 loss = -mll(output, train_y)
40 return loss
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/gpytorch/module.py:82, in Module.__call__(self, *inputs, **kwargs)
81 def __call__(self, *inputs, **kwargs) -> Union[Tensor, Distribution, LinearOperator]:
---> 82 outputs = self.forward(*inputs, **kwargs)
83 if isinstance(outputs, list):
84 return [_validate_module_outputs(output) for output in outputs]
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:82, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params, **kwargs)
79 raise ValueError("NaN observation policy 'fill' is not supported by ExactMarginalLogLikelihood!")
81 # Get the log prob of the marginal distribution
---> 82 res = output.log_prob(target)
83 res = self._add_other_terms(res, params)
85 # Scale by the amount of data we have
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/gpytorch/distributions/multivariate_normal.py:250, in MultivariateNormal.log_prob(self, value)
248 # Get log determininant and first part of quadratic form
249 covar = covar.evaluate_kernel()
--> 250 inv_quad, logdet = covar.inv_quad_logdet(inv_quad_rhs=diff.unsqueeze(-1), logdet=True)
252 res = -0.5 * sum([inv_quad, logdet, diff.size(-1) * math.log(2 * math.pi)])
253 return res
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1756, in LinearOperator.inv_quad_logdet(self, inv_quad_rhs, logdet, reduce_inv_quad)
1753 if inv_quad_rhs is not None:
1754 args = [inv_quad_rhs] + list(args)
-> 1756 preconditioner, precond_lt, logdet_p = self._preconditioner()
1757 if precond_lt is None:
1758 from linear_operator.operators.identity_linear_operator import IdentityLinearOperator
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/linear_operator/operators/added_diag_linear_operator.py:126, in AddedDiagLinearOperator._preconditioner(self)
124 if self._q_cache is None:
125 max_iter = settings.max_preconditioner_size.value()
--> 126 self._piv_chol_self = self._linear_op.pivoted_cholesky(rank=max_iter)
127 if torch.any(torch.isnan(self._piv_chol_self)).item():
128 warnings.warn(
129 "NaNs encountered in preconditioner computation. Attempting to continue without preconditioning.",
130 NumericalWarning,
131 )
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:1973, in LinearOperator.pivoted_cholesky(self, rank, error_tol, return_pivots)
1952 r"""
1953 Performs a partial pivoted Cholesky factorization of the (positive definite) LinearOperator.
1954 :math:`\mathbf L \mathbf L^\top = \mathbf K`.
(...) 1970 https://www.sciencedirect.com/science/article/pii/S0168927411001814
1971 """
1972 func = PivotedCholesky.apply
-> 1973 res, pivots = func(self.representation_tree(), rank, error_tol, *self.representation())
1975 if return_pivots:
1976 return res, pivots
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/torch/autograd/function.py:575, in Function.apply(cls, *args, **kwargs)
572 if not torch._C._are_functorch_transforms_active():
573 # See NOTE: [functorch vjp and autograd interaction]
574 args = _functorch.utils.unwrap_dead_wrappers(args)
--> 575 return super().apply(*args, **kwargs) # type: ignore[misc]
577 if not is_setup_ctx_defined:
578 raise RuntimeError(
579 "In order to use an autograd.Function with functorch transforms "
580 "(vmap, grad, jvp, jacrev, ...), it must override the setup_context "
581 "staticmethod. For more details, please see "
582 "https://pytorch.org/docs/main/notes/extending.func.html"
583 )
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/linear_operator/functions/_pivoted_cholesky.py:78, in PivotedCholesky.forward(ctx, representation_tree, max_iter, error_tol, *matrix_args)
75 # Populater L[... m:, m] with L[..., m:, m] * L[..., m, m].sqrt()
76 if m + 1 < matrix_shape[-1]:
77 # Get next row of the permuted matrix
---> 78 row = apply_permutation(matrix, pi_m.unsqueeze(-1), right_permutation=None).squeeze(-2)
79 pi_i = permutation[..., m + 1 :].contiguous()
81 L_m_new = row.gather(-1, pi_i)
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/linear_operator/utils/permutation.py:80, in apply_permutation(matrix, left_permutation, right_permutation)
76 right_permutation = torch.arange(matrix.size(-1), device=matrix.device)
78 # Apply permutations
79 return to_dense(
---> 80 matrix.__getitem__(
81 (
82 *batch_idx,
83 left_permutation.unsqueeze(-1),
84 right_permutation.unsqueeze(-2),
85 )
86 )
87 )
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/linear_operator/operators/_linear_operator.py:2855, in LinearOperator.__getitem__(self, index)
2849 # Convert all indices into tensor indices
2850 (
2851 *new_batch_indices,
2852 new_row_index,
2853 new_col_index,
2854 ) = _convert_indices_to_tensors(self, flattened_orig_indices)
-> 2855 res = self._get_indices(new_row_index, new_col_index, *new_batch_indices)
2856 # Now un-flatten tensor indices
2857 if len(tensor_index_shape) > 1: # Do we need to unflatten?
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/linear_operator/operators/cat_linear_operator.py:213, in CatLinearOperator._get_indices(self, row_index, col_index, *batch_indices)
210 for linear_op_idx, sub_index in zip(linear_op_indices, sub_indices):
211 sub_index[self.cat_dim] = sub_index[self.cat_dim] - self.cat_dim_cum_sizes[linear_op_idx]
--> 213 res_list = [
214 linear_op._get_indices(sub_index[-2], sub_index[-1], *sub_index[:-2])
215 for linear_op, sub_index in zip(linear_ops, sub_indices)
216 ]
217 if len(res_list) == 1:
218 return res_list[0].view(target_shape).to(self.device)
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/linear_operator/operators/cat_linear_operator.py:214, in <listcomp>(.0)
210 for linear_op_idx, sub_index in zip(linear_op_indices, sub_indices):
211 sub_index[self.cat_dim] = sub_index[self.cat_dim] - self.cat_dim_cum_sizes[linear_op_idx]
213 res_list = [
--> 214 linear_op._get_indices(sub_index[-2], sub_index[-1], *sub_index[:-2])
215 for linear_op, sub_index in zip(linear_ops, sub_indices)
216 ]
217 if len(res_list) == 1:
218 return res_list[0].view(target_shape).to(self.device)
File /central/home/jyzhao/python_env/gpytorch/gpytorch/lib/python3.11/site-packages/linear_operator/operators/dense_linear_operator.py:50, in DenseLinearOperator._get_indices(self, row_index, col_index, *batch_indices)
48 def _get_indices(self, row_index: IndexType, col_index: IndexType, *batch_indices: IndexType) -> torch.Tensor:
49 # Perform the __getitem__
---> 50 res = self.tensor[(*batch_indices, row_index, col_index)]
51 return res
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)