
[Bug] Posterior on fantasy models can be memory inefficient on batches

Status: Open. JackBuck opened this issue 7 months ago (1 comment).

🐛 Bug

When using a fantasy model with a batch dimension and evaluating its posterior on a test set that has additional batch dimensions, GPyTorch / PyTorch internally tries to allocate excessively large tensors.

To reproduce

The following example uses gpytorch.settings.fast_pred_var(), which is what BoTorch uses and is the case I have been investigating locally. Without this setting, the example also runs out of memory, but in a different place.

```python
import torch
import gpytorch.settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP

torch.set_default_dtype(torch.double)


class SimpleGP(ExactGP):
    def __init__(self, train_inputs, train_targets):
        super().__init__(train_inputs, train_targets, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = RBFKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


if __name__ == "__main__":
    d = 2
    n_train = 100
    gp = SimpleGP(
        train_inputs=torch.rand(n_train, d, dtype=torch.double),
        train_targets=torch.rand(n_train, dtype=torch.double),
    ).eval()
    gp(torch.rand(5, d, dtype=torch.double))  # set the caches before fantasize.

    num_fantasies = 64
    x = torch.rand(256, 1, d)
    y = torch.rand(num_fantasies, 256, 1)
    fantasy_model = gp.get_fantasy_model(x, y).eval()

    x_test = torch.rand(50, num_fantasies, 256, 1, d)
    with gpytorch.settings.fast_pred_var():
        fantasy_model(x_test)  # Tries to allocate a tensor of 64GB
```
```text
Traceback (most recent call last):
  File ".../mwe_fantasize_memory_issue.py", line 43, in <module>
    fantasy_model(x_test)  # Tries to allocate a tensor of 64GB
    ~~~~~~~~~~~~~^^^^^^^^
  File ".../lib/python3.13/site-packages/gpytorch/models/exact_gp.py", line 345, in __call__
    ) = self.prediction_strategy.exact_prediction(full_mean, full_covar)
        ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.13/site-packages/gpytorch/models/exact_prediction_strategies.py", line 325, in exact_prediction
    self.exact_predictive_covar(test_test_covar, test_train_covar),
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../lib/python3.13/site-packages/gpytorch/models/exact_prediction_strategies.py", line 411, in exact_predictive_covar
    covar_inv_quad_form_root = self._exact_predictive_covar_inv_quad_form_root(precomputed_cache, test_train_covar)
  File ".../lib/python3.13/site-packages/gpytorch/models/exact_prediction_strategies.py", line 117, in _exact_predictive_covar_inv_quad_form_root
    return test_train_covar.matmul(precomputed_cache)
           ~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^
RuntimeError: [enforce fail at alloc_cpu.cpp:118] err == 0. DefaultCPUAllocator: can't allocate memory: you tried to allocate 66853273600 bytes. Error code 12 (Cannot allocate memory)
```

Expected Behavior

I would expect GPyTorch to be able to do this computation without excessive memory requests.

System information

  • GPyTorch Version: 1.14
  • PyTorch Version: 2.6.0+cpu
  • Computer OS: Ubuntu 20.04.6 LTS (Focal Fossa)

Additional context

The allocation fails on the line `return test_train_covar.matmul(precomputed_cache)`. The shapes of the two operands are:

  • test_train_covar: (50, 64, 256, 1, 101)
  • precomputed_cache: (256, 101, 101)
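Multiplying out these shapes gives a quick consistency check. If torch.matmul expands precomputed_cache to the broadcast batch shape of the two operands, that expanded tensor alone needs exactly the byte count reported in the RuntimeError. The sketch below only does shape arithmetic (nothing large is allocated) and assumes dense float64 storage, as set by torch.set_default_dtype(torch.double) in the example.

```python
import math

import torch

# Operand shapes reported above (dense float64).
test_train_covar_shape = (50, 64, 256, 1, 101)
precomputed_cache_shape = (256, 101, 101)

# torch.matmul broadcasts all dimensions except the last two.
batch_shape = torch.broadcast_shapes(test_train_covar_shape[:-2], precomputed_cache_shape[:-2])

# Shape and size of precomputed_cache if materialized at the broadcast batch shape.
expanded_shape = (*batch_shape, *precomputed_cache_shape[-2:])
bytes_per_element = torch.empty((), dtype=torch.double).element_size()  # 8 for float64
expanded_bytes = math.prod(expanded_shape) * bytes_per_element

print(expanded_shape, expanded_bytes)  # (50, 64, 256, 101, 101) 66853273600
```

By comparison, the result of the product has shape (50, 64, 256, 1, 101), which is about 662 MB in float64, so almost all of the requested memory goes to the expanded copy of precomputed_cache.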

I believe that internally, torch.matmul physically broadcasts (materializes) precomputed_cache to shape (50, 64, 256, 101, 101), which causes the memory issue. I also believe that this is unnecessary (see this PyTorch issue: https://github.com/pytorch/pytorch/issues/154128). However, I am opening this issue in GPyTorch as well because (a) there may be a simpler fix in GPyTorch, since it has more information about the tensor dimensions (although maybe not; I'm not sure which dimensions are possible), and (b) without the gpytorch.settings.fast_pred_var() setting, the memory issue occurs in a different place (I think during the Cholesky factorisation), which my suggested fix in PyTorch would not address.
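As an illustration of what a GPyTorch-side fix might look like, here is a minimal sketch under one assumption about the shapes: precomputed_cache carries only the trailing batch dimension of test_train_covar, as in the shapes above. In that case the extra batch dimensions can be folded into the row dimension so that a single torch.bmm call suffices and precomputed_cache is never expanded. The helper name matmul_small_rhs is hypothetical; it is not part of GPyTorch or PyTorch, it works on dense torch tensors rather than GPyTorch's LinearOperator objects, and it covers only this one shape pattern.

```python
import torch


def matmul_small_rhs(lhs: torch.Tensor, rhs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: compute lhs @ rhs without expanding rhs.

    Assumes lhs has shape (*extra, k, m, n) and rhs has shape (k, n, p),
    i.e. rhs only carries the trailing batch dimension k of lhs.
    """
    *extra, k, m, n = lhs.shape
    if rhs.dim() != 3 or rhs.shape[0] != k or rhs.shape[1] != n:
        raise ValueError("rhs must have shape (k, n, p) matching lhs")
    p = rhs.shape[-1]
    # Move the shared batch dim k to the front: (k, *extra, m, n).
    lhs_k = lhs.movedim(-3, 0)
    # Fold the extra batch dims into the row dim: (k, prod(extra) * m, n).
    # This copies lhs once, but never touches a (*extra, k, n, p) tensor.
    lhs_k = lhs_k.reshape(k, -1, n)
    # One batched matmul over k; rhs is used as-is: (k, prod(extra) * m, p).
    out = torch.bmm(lhs_k, rhs)
    # Restore the original layout: (*extra, k, m, p).
    return out.reshape(k, *extra, m, p).movedim(0, -3)


if __name__ == "__main__":
    # Small shapes with the same structure as (50, 64, 256, 1, 101) @ (256, 101, 101).
    lhs = torch.randn(3, 4, 5, 1, 6, dtype=torch.double)
    rhs = torch.randn(5, 6, 7, dtype=torch.double)
    print(torch.allclose(matmul_small_rhs(lhs, rhs), torch.matmul(lhs, rhs)))
```

With the issue's shapes, the largest intermediates here are the rearranged copy of test_train_covar and the output, each 50 × 64 × 256 × 1 × 101 float64 elements (about 662 MB). Whether GPyTorch can rely on this shape pattern in general is exactly the open question raised in (a).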

JackBuck, May 22 '25 18:05

Another related issue here: https://github.com/pytorch/botorch/issues/2310#issue-2265529284

Balandat, May 23 '25 00:05