[Bug] Fantasy Models for Multitask GPs are broken

Open ancorso opened this issue 1 year ago • 5 comments

🐛 Bug

Getting a fantasy model for a simple multi-task GP throws an error.

To reproduce

Here is a minimal working example of the bug:

import torch
import gpytorch

class MultitaskGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_tasks):
        super(MultitaskGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.MultitaskMean(
            gpytorch.means.ConstantMean(), num_tasks=n_tasks
        )
        self.covar_module = gpytorch.kernels.MultitaskKernel(
            gpytorch.kernels.RBFKernel(), num_tasks=n_tasks, rank=1
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)


input_dim = 1
output_dim = 2
n_train = 10
train_x = torch.randn(n_train, input_dim)
train_y = torch.randn(n_train, output_dim)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=output_dim)
model = MultitaskGPModel(train_x, train_y, likelihood, output_dim)

model.train()
model.eval()

# get a posterior to fill in caches
model(torch.randn(n_train, input_dim))

# Generate some new data and get fantasy model
n_new = 5
new_x = torch.randn(n_new, input_dim)
new_y = torch.randn(n_new, output_dim)

model.get_fantasy_model(new_x, new_y)

** Stack trace/error message **

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/acorso/micromamba/envs/newenv/lib/python3.12/site-packages/gpytorch/models/exact_gp.py", line 239, in get_fantasy_model
    new_model.prediction_strategy = old_pred_strat.get_fantasy_strategy(
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/acorso/micromamba/envs/newenv/lib/python3.12/site-packages/gpytorch/models/exact_prediction_strategies.py", line 196, in get_fantasy_strategy
    small_system_rhs = targets - fant_mean - ftcm
                       ~~~~~~~~~~~~~~~~~~~~^~~~~~
RuntimeError: The size of tensor a (2) must match the size of tensor b (10) at non-singleton dimension 1

Expected Behavior

The fantasy model, with the appropriately updated cache, should be returned.

System information

  • gpytorch 1.12
  • torch 2.4.0+cu121
  • MacOS

Additional context

This has already been discussed in https://github.com/cornellius-gp/gpytorch/issues/800 and https://github.com/cornellius-gp/gpytorch/pull/805, and a PR that supposedly implemented this feature was merged: https://github.com/cornellius-gp/gpytorch/pull/2317. However, the test added there passes only because it adds a single datapoint to produce the fantasy model. If you set n_new = 1 in the example above, it also runs without error, but I'm skeptical that the right thing is happening if it doesn't work for more than one additional point.
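
One possible explanation for why n_new = 1 runs without error is broadcasting. The sketch below assumes, based on the sizes 2 and 10 in the error message, that an n_new x output_dim tensor is combined with a flat tensor of length n_new * output_dim. `broadcast_shape` and `residual_shape` are hypothetical helpers that mimic PyTorch's broadcasting rules in plain Python; they are not gpytorch code.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Broadcast two shapes using NumPy/PyTorch rules, or raise ValueError."""
    out = []
    # Compare dimensions right-to-left; a missing dimension counts as size 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"cannot broadcast {a} with {b}")
        out.append(y if x == 1 else x)
    return tuple(reversed(out))


def residual_shape(n_new, output_dim=2):
    """Shape of (n_new x output_dim tensor) - (flat tensor), under the assumption above."""
    two_d = (n_new, output_dim)    # assumed shape of targets - fant_mean
    flat = (n_new * output_dim,)   # assumed shape of the flat term
    return broadcast_shape(two_d, flat)


print(residual_shape(1))  # (1, 2) and (2,) are broadcast-compatible
try:
    residual_shape(5)     # (5, 2) and (10,) are not
except ValueError as err:
    print(err)
```

If this assumption holds, the n_new = 1 case passes only because a (1, 2) tensor happens to broadcast against a (2,) tensor, so it says little about whether the resulting values are correct.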

ancorso avatar Sep 01 '24 17:09 ancorso

After some investigation, it seems that this line https://github.com/cornellius-gp/gpytorch/blob/44993efcc180bdbdeaaf2107c7cc1ba532b2da9b/gpytorch/models/exact_prediction_strategies.py#L194 produces a 1D tensor of length m*d_out, where m is the number of new (fantasy) points and d_out is the output dimension of the GP. In the next line https://github.com/cornellius-gp/gpytorch/blob/44993efcc180bdbdeaaf2107c7cc1ba532b2da9b/gpytorch/models/exact_prediction_strategies.py#L196 the terms targets and fant_mean are tensors of size m x d_out. So this looks like an easy fix by reshaping, but I'm not sure which shape is correct. If someone can weigh in on what is correct here, I am happy to submit a PR with a new test example that covers this case. @gpleiss?
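
To illustrate why the shape alone does not settle the question, here is a minimal plain-Python sketch of the two obvious ways to turn a flat vector of length m*d_out into an m x d_out layout. The helper names are hypothetical, and which ordering the flat tensor actually uses is not established in this thread.

```python
def unflatten_point_major(flat, m, d_out):
    """Assume flat = [p0t0, p0t1, p1t0, p1t1, ...] (tasks vary fastest).

    Roughly what flat.view(m, d_out) would do on a torch tensor.
    """
    return [flat[i * d_out:(i + 1) * d_out] for i in range(m)]


def unflatten_task_major(flat, m, d_out):
    """Assume flat = [t0p0, t0p1, ..., t1p0, t1p1, ...] (points vary fastest).

    Roughly what flat.view(d_out, m).T would do on a torch tensor.
    """
    return [[flat[t * m + i] for t in range(d_out)] for i in range(m)]


m, d_out = 3, 2
flat = list(range(m * d_out))  # stand-in values 0..5

print(unflatten_point_major(flat, m, d_out))  # [[0, 1], [2, 3], [4, 5]]
print(unflatten_task_major(flat, m, d_out))   # [[0, 3], [1, 4], [2, 5]]
```

Both results have shape m x d_out, so either reshape would silence the error, but they pair different entries with each target. The right choice depends on how the flat term orders points and tasks.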

ancorso avatar Sep 01 '24 17:09 ancorso

@ancorso we are in the middle of reworking the prediction strategies code (timeline TBD) for a 2.0 release. However, we'd accept a bugfix PR for the time being (as long as it's not too much work on your end!)

gpleiss avatar Sep 03 '24 01:09 gpleiss

cc @hvarfner

Balandat avatar Sep 05 '24 03:09 Balandat

@gpleiss We have put together a bugfix PR here https://github.com/cornellius-gp/gpytorch/pull/2587, which passes tests that we ran locally. If you have some time, let us know what you

williamjsdavis avatar Sep 12 '24 22:09 williamjsdavis

Any updates on this? Thank you!

aadityacs avatar Nov 12 '25 15:11 aadityacs