[Bug] Batched outputs with non-batched inputs produce incorrect shape with posterior

Open · wjmaddox opened this issue Apr 19 '22 · 10 comments

🐛 Bug

Looks like the Normalize transform is incorrectly double-batching the posterior, but only after training the model. Not entirely sure where the bug is:

To reproduce

**Code snippet to reproduce**

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms import Normalize
from botorch.optim.fit import fit_gpytorch_torch
from gpytorch.mlls import ExactMarginalLogLikelihood

train_x = torch.randn(170, 1)
train_y = torch.randn(170, 8)
test_x = torch.randn(80, 1)

model = SingleTaskGP(train_x, train_y, input_transform=Normalize(d=1))
mll = ExactMarginalLogLikelihood(model.likelihood, model)

with torch.no_grad():
    posterior = model.posterior(train_x)
    print(posterior.mean.shape)  # 170 x 8, as expected

fit_gpytorch_torch(mll)  # training seems required to reproduce the bug

with torch.no_grad():
    posterior = model.posterior(train_x)
    print(posterior.mean.shape)  # 8 x 170 x 8, which is incorrect
```

**Stack trace/error message**

See code.

Expected Behavior

Would expect the posterior mean after training to still be 170 x 8.

System information

Please complete the following information:

  • botorch @ 9a93afb407
  • gpytorch @ master
  • PyTorch 1.11
  • Linux

Additional context

Happy to try to debug tomorrow.

wjmaddox (Apr 19 '22)

That's a weird issue. I can't seem to replicate it, though; it might be isolated to Linux. I tried on a Mac.

saitcakmak (Apr 19 '22)

Not sure what's going on there, but my first guess would be that this has something to do with the batch-model-to-single-model-to-batch-model conversion that we're doing by default during the fitting: https://github.com/pytorch/botorch/blob/main/botorch/fit.py#L79-L109

Balandat (Apr 19 '22)

> That's a weird issue. I can't seem to replicate it, though; it might be isolated to Linux. I tried on a Mac.

Huh that is interesting...

Balandat (Apr 19 '22)

Weird. I can't replicate on my Mac, but I could replicate on another Linux server before pulling the latest botorch (from 353f3764) / gpytorch master (from 32cde571). Then a fresh install of Python on my current machine fixed it as well. Closing for now.

> but my first guess would be that this has something to do with the batch-model-to-single-model-to-batch-model conversion that we're doing by default during the fitting: https://github.com/pytorch/botorch/blob/main/botorch/fit.py#L79-L109

I would have thought it was this too, but I wasn't using fit_gpytorch_model; I was calling fit_gpytorch_torch directly.

wjmaddox (Apr 19 '22)

Did you figure out what’s wrong here?

Balandat (Apr 19 '22)

I did not. I just tried replicating on botorch (from https://github.com/pytorch/botorch/commit/353f37649fa8d90d881e8ea20c11986b15723ef1) / gpytorch master (from 32cde571) and was able to reproduce on my Mac. Actually, the first posterior call fails there as well; will debug from there.

wjmaddox (Apr 19 '22)

I don't have a full understanding of what's going on yet, but it seems to be related to how botorch handles the batching internally:

```python
import gpytorch

with gpytorch.settings.debug(False):
    model(train_x)
    print(model.input_transform.mins.shape)  # 1 x 1, as expected

model(*model.train_inputs)
print(model.input_transform.mins.shape)  # now 8 x 1 x 1
```

This means that as we move through the posterior call, the input picks up extra batch dimensions. Tracing through: if the input starts as 80 x 1, it becomes 8 x 80 x 1 (probably okay) after the input transform is applied here, and then becomes 8 x 1 x 80 x 1 (not correct) because botorch thinks there are extra inputs here.

This seems to be fixed on main because the mins shape is 1 x 1 as expected.

wjmaddox (Apr 19 '22)

> This seems to be fixed on main because the mins shape is 1 x 1 as expected.

Which main? Botorch or gpytorch? I recently landed a fix to gpytorch's lazy evaluated kernel tensor that could be relevant here: https://github.com/cornellius-gp/gpytorch/pull/1971

Balandat (Apr 19 '22)

Sorry for the imprecision: botorch main.

wjmaddox (Apr 20 '22)

Hmm, then I don't know what's going on here...

Balandat (Apr 20 '22)

I'm currently not able to reproduce this on either a Mac or Linux. Maybe it fixed itself?

esantorella (Jan 31 '23)