
[Bug] `mean_cache` not attached with correct `args` in `get_fantasy_model`


🐛 Bug

When a model is fantasized, the mean and covariance caches are recreated and reattached. However, get_fantasy_model attaches the mean cache with no args, whereas DefaultPredictionStrategy.mean_cache reads it with settings.observation_nan_policy.value() as an arg. Because the cache keys differ, fantasy models recompute the mean cache unnecessarily, which is inefficient, particularly when many fantasy models are created repeatedly (for example, in an optimization loop).
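To make the mismatch concrete, here is a minimal, self-contained sketch of an args-keyed memoization cache in the style of gpytorch.utils.memoize (schematic only, not GPyTorch's actual implementation): an entry written under no positional args is invisible to a lookup that passes the nan policy, so the lookup misses and the expensive value is recomputed.

import pickle

_cache = {}

def add_to_cache(name, value, *args, **kwargs):
    # The key includes the positional args and the pickled kwargs.
    _cache[(name, args, pickle.dumps(kwargs))] = value

def get_from_cache(name, *args, **kwargs):
    return _cache.get((name, args, pickle.dumps(kwargs)))

# get_fantasy_model writes the cache with no args ...
add_to_cache("mean_cache", "precomputed mean cache")

# ... but the read path passes the nan policy as an arg, so the
# lookup misses and the value has to be recomputed.
print(get_from_cache("mean_cache", "ignore"))  # None (cache miss)
print(get_from_cache("mean_cache"))            # hit, but never used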

To reproduce

MWE kindly adapted from #2631; however, beyond the setup I don't know whether these two issues are actually related.

Code snippet to reproduce

import torch
from gpytorch import settings as gpt_settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.models import ExactGP

torch.set_default_dtype(torch.double)

d = 10
mc_points = torch.rand(32, d, dtype=torch.double)


class SimpleGP(ExactGP):
    def __init__(self, train_inputs, train_targets):
        super().__init__(train_inputs, train_targets, GaussianLikelihood())
        self.mean_module = ConstantMean()
        self.covar_module = RBFKernel()

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)


gp = SimpleGP(
    train_inputs=torch.rand(256, d, dtype=torch.double),
    train_targets=torch.rand(256, dtype=torch.double),
).eval()
gp(torch.rand(5, d, dtype=torch.double))  # populate the caches before fantasizing.

print(gp.prediction_strategy._memoize_cache.keys())

X = torch.rand(128, 5, d, dtype=torch.double, requires_grad=False)
Y = torch.rand(128, 5, dtype=torch.double, requires_grad=False)

fantasy_model = gp.get_fantasy_model(inputs=X, targets=Y).eval()

print(fantasy_model.prediction_strategy._memoize_cache.keys())

Stdout

dict_keys([('mean_cache', ('ignore',), b'\x80\x04}\x94.')])
dict_keys([('mean_cache', (), b'\x80\x04}\x94.'), ('covar_cache', (), b'\x80\x04}\x94.')])
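(Each key is a tuple of the cache name, the positional args, and the pickled kwargs; b'\x80\x04}\x94.' is just the pickled empty kwargs dict. Note the ('ignore',) vs () positional args on mean_cache.)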

Expected Behavior

I would expect the mean_cache on the fantasy model to be attached with the same key as it has on self. In the case above, that means ('mean_cache', ('ignore',), b'\x80\x04}\x94.') rather than ('mean_cache', (), b'\x80\x04}\x94.').

I would perhaps also expect get_fantasy_model to be aware of the observation_nan_policy setting.
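Concretely, something along these lines in get_fantasy_strategy would make the write-side key match the read-side key (a sketch only, not a tested patch; attach_fantasy_mean_cache is a hypothetical helper name, and I'm assuming add_to_cache forwards extra positional args into the cache key, as the stdout above suggests):

from gpytorch import settings
from gpytorch.utils.memoize import add_to_cache

def attach_fantasy_mean_cache(strategy, mean_cache):
    # Hypothetical helper: `strategy` stands in for the new fantasy
    # prediction strategy, `mean_cache` for its precomputed mean cache.
    # Passing the nan policy makes the attach key match the key used
    # when DefaultPredictionStrategy.mean_cache is read.
    add_to_cache(
        strategy,
        "mean_cache",
        mean_cache,
        settings.observation_nan_policy.value(),
    )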

System information

Please complete the following information:

  • GPyTorch Version: 1.15.dev37+g8433c0b86
  • PyTorch Version: 2.8.0+cu128
  • Computer OS: Ubuntu 20.04.6 LTS (Focal Fossa)

Additional context

Line where the mean_cache is accessed using an argument: https://github.com/cornellius-gp/gpytorch/blob/8433c0b86660d0f318e3b0b5175bcfc5f9967894/gpytorch/models/exact_prediction_strategies.py#L255

Line where the mean_cache is added during get_fantasy_model: https://github.com/cornellius-gp/gpytorch/blob/8433c0b86660d0f318e3b0b5175bcfc5f9967894/gpytorch/models/exact_prediction_strategies.py#L242

JackBuck commented on Oct 01 '25

Oof good catch. Any chance you could put up a PR?

gpleiss commented on Oct 03 '25

Sure! Will do next week. I'm not familiar with the different nan-handling logic, though. Currently get_fantasy_strategy has just one path; are you happy for it to stay that way? I.e., the cache key gets set to whatever the existing key is, regardless of the current observation_nan_policy setting, and the logic in get_fantasy_strategy remains unchanged.

JackBuck commented on Oct 04 '25