botorch
[Feature Request] InputDataWarnings are confusing
Motivation
It's common to see warnings like this:
import torch
from botorch.models import SingleTaskGP
x = torch.linspace(-5, 10, 100).unsqueeze(-1)
model = SingleTaskGP(train_X=x, train_Y=x)
/botorch/models/utils/assorted.py:173: InputDataWarning:
Input data is not contained to the unit cube. Please consider min-max scaling the input data.
/botorch/models/utils/assorted.py:201: InputDataWarning:
Input data is not standardized. Please consider scaling the input to zero mean and unit variance.
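For reference, the two rescalings these warnings ask for can be sketched in plain Python on the same grid of points. This only illustrates the arithmetic and is not BoTorch's own check; the helper names min_max_scale and standardize are made up for this example.

```python
import statistics


def min_max_scale(values):
    """Map values linearly onto [0, 1], the one-dimensional unit cube."""
    lo, hi = min(values), max(values)
    return [(v - lo) / (hi - lo) for v in values]


def standardize(values):
    """Shift and scale values to zero mean and unit sample variance."""
    mean = statistics.fmean(values)
    std = statistics.stdev(values)
    return [(v - mean) / std for v in values]


# Same grid of points as torch.linspace(-5, 10, 100)
x = [-5 + 15 * i / 99 for i in range(100)]

x_scaled = min_max_scale(x)      # what the first warning asks for
y_standardized = standardize(x)  # what the second warning asks for
```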
Pitch
To improve these warnings,
- Make it clear which of these is about train_X and which is about train_Y. (Should the y data be "outcome" data?)
- Suggest using transforms to fix this. In this case, a better implementation would be
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
x = torch.linspace(-5, 10, 100).unsqueeze(-1)
model = SingleTaskGP(
    train_X=x,
    train_Y=x,
    input_transform=Normalize(d=1),
    outcome_transform=Standardize(m=1),
)
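One possible shape for the reworded warnings is sketched below. It is a hypothetical, BoTorch-free stand-in: the InputDataWarning class here is a local substitute for the real one, the function names and the tolerance of 1e-2 are assumptions, and plain lists replace tensors. The point is only the wording, which names the argument and suggests the matching transform.

```python
import statistics
import warnings


class InputDataWarning(UserWarning):
    """Local stand-in for BoTorch's warning class so the sketch runs on its own."""


def check_min_max_scaling(train_X, atol=1e-2):
    """Warn if any entry of train_X (a list of rows) lies outside [0, 1]."""
    flat = [v for row in train_X for v in row]
    if min(flat) < -atol or max(flat) > 1 + atol:
        warnings.warn(
            "Data in `train_X` is not contained to the unit cube. Consider "
            "min-max scaling it, e.g. by passing "
            "`input_transform=Normalize(d=train_X.shape[-1])` to the model.",
            InputDataWarning,
            stacklevel=2,
        )


def check_standardization(train_Y, atol=1e-2):
    """Warn if the outcomes in train_Y (a flat list) are not standardized."""
    mean = statistics.fmean(train_Y)
    std = statistics.stdev(train_Y)
    if abs(mean) > atol or abs(std - 1) > atol:
        warnings.warn(
            "Outcome data in `train_Y` is not standardized. Consider scaling "
            "it to zero mean and unit variance, e.g. by passing "
            "`outcome_transform=Standardize(m=train_Y.shape[-1])` to the model.",
            InputDataWarning,
            stacklevel=2,
        )


# Same data as in the issue: 100 points between -5 and 10, one column.
x = [[-5 + 15 * i / 99] for i in range(100)]
check_min_max_scaling(x)                      # warns about train_X
check_standardization([row[0] for row in x])  # warns about train_Y
```

With messages like these, a reader can tell at a glance which argument triggered which warning and which transform addresses it.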
Are you willing to open a pull request? (See CONTRIBUTING)
Yes, but this is pretty easy so it would make a good first task for a newcomer
Hi, I am a newcomer, can I take this?
Most certainly, thanks! Let us know if you need any help, but most of what you need should be here: https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md
Definitely! The one thing that may be tricky here is that BoTorch requires 100% test coverage and requires all tests to pass; to achieve that, you'll need to update existing tests such as this one. We also suppress the input data warnings in unit tests here, so that may need to be updated too.
These warnings can be locally brought back and checked for using something like
with warnings.catch_warnings(record=True) as ws:
    warnings.simplefilter("always")
    function_call_to_produce_the_warning(...)
self.assertTrue(any(expected_warning_msg in str(w.message) for w in ws))
I'd prefer a local solution like this to bringing them back for all of the test suite.
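A self-contained version of that local pattern might look like the sketch below. It uses only the standard library; InputDataWarning, emit_input_data_warning, and the message text are hypothetical stand-ins for the real BoTorch pieces, and the outer "ignore" filter merely simulates suite-wide suppression.

```python
import unittest
import warnings


class InputDataWarning(UserWarning):
    """Hypothetical stand-in for BoTorch's InputDataWarning."""


def emit_input_data_warning():
    """Hypothetical function under test; it just emits the warning."""
    warnings.warn(
        "Data in `train_X` is not contained to the unit cube.",
        InputDataWarning,
    )


class TestInputDataWarning(unittest.TestCase):
    def test_warning_is_emitted(self):
        expected_warning_msg = "train_X"
        with warnings.catch_warnings():
            # Simulate a suite-wide filter that silences the warning.
            warnings.simplefilter("ignore", category=InputDataWarning)
            with warnings.catch_warnings(record=True) as ws:
                # Re-enable all warnings inside this block only; the outer
                # filters are restored when the block exits.
                warnings.simplefilter("always")
                emit_input_data_warning()
        self.assertTrue(
            any(
                issubclass(w.category, InputDataWarning)
                and expected_warning_msg in str(w.message)
                for w in ws
            )
        )
```

Saved in a test module, this can be run with python -m unittest. Because catch_warnings restores the previous filters on exit, the rest of the suite keeps its suppression.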