input_transform Normalize does not seem to work properly with condition_on_observations
🐛 Bug
After running model.condition_on_observations(new_x, new_y) on a model that was originally instantiated with Normalize(d), the resulting model fails during retraining. I believe this is a bug, but I'm honestly not sure.
To reproduce
Step 1: initialize dummy data
import botorch
import gpytorch
import numpy as np
import torch
from botorch.models.transforms.input import Normalize

np.random.seed(123)
torch.manual_seed(123)
# use regular spaced points on the interval [0, 1]
train_x = torch.linspace(0, 1, 15)
# training data needs to be explicitly multi-dimensional
train_x = train_x.unsqueeze(1)
# sample observed values and add some synthetic noise
train_y = torch.sin(train_x * (2 * np.pi)) + 0.15 * torch.randn_like(train_x)
Step 2: initialization/training, works just fine
model = botorch.models.SingleTaskGP(
    train_X=train_x,
    train_Y=train_y,
    input_transform=Normalize(1, transform_on_eval=True),
)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll)
Step 3: condition
new_x = torch.FloatTensor(np.array([1.25, 1.5]).reshape(-1, 1))
new_y = torch.FloatTensor(np.array([-1.0, -2.0]).reshape(-1, 1))
model = model.condition_on_observations(new_x, new_y)
Step 4: attempt retraining to further tune hyperparameters (lengthscales and whatnot)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
botorch.fit.fit_gpytorch_mll(mll) # fails
Stack trace/error message
MDNotImplementedError Traceback (most recent call last)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/utils/dispatcher.py:88, in Dispatcher.__call__(self, *args, **kwargs)
87 try:
---> 88 return func(*args, **kwargs)
89 except MDNotImplementedError:
90 # Traverses registered methods in order, yields whenever a match is found
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:320, in _fit_multioutput_independent(mll, _, __, sequential, **kwargs)
315 if ( # incompatible models
316 not sequential
317 or mll.model.num_outputs == 1
318 or mll.likelihood is not getattr(mll.model, "likelihood", None)
319 ):
--> 320 raise MDNotImplementedError # defer to generic
322 # TODO: Unpacking of OutcomeTransforms not yet supported. Targets are often
323 # pre-transformed in __init__, so try fitting with outcome_transform hidden
MDNotImplementedError:
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Input In [20], in <cell line: 2>()
1 mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood=model.likelihood, model=model)
----> 2 botorch.fit.fit_gpytorch_mll(mll)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:114, in fit_gpytorch_mll(mll, optimizer, optimizer_kwargs, **kwargs)
111 if optimizer is not None: # defer to per-method defaults
112 kwargs["optimizer"] = optimizer
--> 114 return dispatcher(
115 mll,
116 type(mll.likelihood),
117 type(mll.model),
118 optimizer_kwargs=optimizer_kwargs,
119 **kwargs,
120 )
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/utils/dispatcher.py:95, in Dispatcher.__call__(self, *args, **kwargs)
93 for func in funcs:
94 try:
---> 95 return func(*args, **kwargs)
96 except MDNotImplementedError:
97 pass
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/fit.py:240, in _fit_fallback(mll, _, __, optimizer, optimizer_kwargs, max_attempts, warning_filter, caught_exception_types, **ignore)
238 with catch_warnings(record=True) as warning_list, debug(True):
239 simplefilter("always", category=OptimizationWarning)
--> 240 mll, _ = optimizer(mll, **optimizer_kwargs)
242 # Resolve warning messages and determine whether or not to retry
243 done = True
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/fit.py:142, in fit_gpytorch_scipy(mll, bounds, method, options, track_iterations, approx_mll, scipy_objective, module_to_array_func, module_from_array_func)
140 cb = store_iteration if track_iterations else None
141 with gpt_settings.fast_computations(log_prob=approx_mll):
--> 142 res = minimize(
143 scipy_objective,
144 x0,
145 args=(mll, property_dict),
146 bounds=bounds,
147 method=method,
148 jac=True,
149 options=options,
150 callback=cb,
151 )
152 iterations = []
153 if track_iterations:
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_minimize.py:692, in minimize(fun, x0, args, method, jac, hess, hessp, bounds, constraints, tol, callback, options)
689 res = _minimize_newtoncg(fun, x0, args, jac, hess, hessp, callback,
690 **options)
691 elif meth == 'l-bfgs-b':
--> 692 res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
693 callback=callback, **options)
694 elif meth == 'tnc':
695 res = _minimize_tnc(fun, x0, args, jac, bounds, callback=callback,
696 **options)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_lbfgsb_py.py:308, in _minimize_lbfgsb(fun, x0, args, jac, bounds, disp, maxcor, ftol, gtol, eps, maxfun, maxiter, iprint, callback, maxls, finite_diff_rel_step, **unknown_options)
305 else:
306 iprint = disp
--> 308 sf = _prepare_scalar_function(fun, x0, jac=jac, args=args, epsilon=eps,
309 bounds=new_bounds,
310 finite_diff_rel_step=finite_diff_rel_step)
312 func_and_grad = sf.fun_and_grad
314 fortran_int = _lbfgsb.types.intvar.dtype
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:263, in _prepare_scalar_function(fun, x0, jac, args, bounds, epsilon, finite_diff_rel_step, hess)
259 bounds = (-np.inf, np.inf)
261 # ScalarFunction caches. Reuse of fun(x) during grad
262 # calculation reduces overall function evaluations.
--> 263 sf = ScalarFunction(fun, x0, args, grad, hess,
264 finite_diff_rel_step, bounds, epsilon=epsilon)
266 return sf
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:158, in ScalarFunction.__init__(self, fun, x0, args, grad, hess, finite_diff_rel_step, finite_diff_bounds, epsilon)
155 self.f = fun_wrapped(self.x)
157 self._update_fun_impl = update_fun
--> 158 self._update_fun()
160 # Gradient evaluation
161 if callable(grad):
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:251, in ScalarFunction._update_fun(self)
249 def _update_fun(self):
250 if not self.f_updated:
--> 251 self._update_fun_impl()
252 self.f_updated = True
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:155, in ScalarFunction.__init__.<locals>.update_fun()
154 def update_fun():
--> 155 self.f = fun_wrapped(self.x)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py:137, in ScalarFunction.__init__.<locals>.fun_wrapped(x)
133 self.nfev += 1
134 # Send a copy because the user may overwrite it.
135 # Overwriting results in undefined behaviour because
136 # fun(self.x) will change self.x, with the two no longer linked.
--> 137 fx = fun(np.copy(x), *args)
138 # Make sure the function returns a true scalar
139 if not np.isscalar(fx):
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:76, in MemoizeJac.__call__(self, x, *args)
74 def __call__(self, x, *args):
75 """ returns the the function value """
---> 76 self._compute_if_needed(x, *args)
77 return self._value
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/scipy/optimize/_optimize.py:70, in MemoizeJac._compute_if_needed(self, x, *args)
68 if not np.all(x == self.x) or self._value is None or self.jac is None:
69 self.x = np.asarray(x).copy()
---> 70 fg = self.fun(x, *args)
71 self.jac = fg[1]
72 self._value = fg[0]
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:227, in _scipy_objective_and_grad(x, mll, property_dict)
225 loss = -mll(*args).sum()
226 except RuntimeError as e:
--> 227 return _handle_numerical_errors(error=e, x=x)
228 loss.backward()
230 i = 0
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:256, in _handle_numerical_errors(error, x)
250 if (
251 isinstance(error, NanError)
252 or "singular" in error_message # old pytorch message
253 or "input is not positive-definite" in error_message # since pytorch #63864
254 ):
255 return float("nan"), np.full_like(x, "nan")
--> 256 raise error
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/botorch/optim/utils.py:225, in _scipy_objective_and_grad(x, mll, property_dict)
223 output = mll.model(*train_inputs)
224 args = [output, train_targets] + _get_extra_mll_args(mll)
--> 225 loss = -mll(*args).sum()
226 except RuntimeError as e:
227 return _handle_numerical_errors(error=e, x=x)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/module.py:30, in Module.__call__(self, *inputs, **kwargs)
29 def __call__(self, *inputs, **kwargs):
---> 30 outputs = self.forward(*inputs, **kwargs)
31 if isinstance(outputs, list):
32 return [_validate_module_outputs(output) for output in outputs]
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
62 # Get the log prob of the marginal distribution
63 output = self.likelihood(function_dist, *params)
---> 64 res = output.log_prob(target)
65 res = self._add_other_terms(res, params)
67 # Scale by the amount of data we have
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/gpytorch/distributions/multivariate_normal.py:147, in MultivariateNormal.log_prob(self, value)
145 def log_prob(self, value):
146 if settings.fast_computations.log_prob.off():
--> 147 return super().log_prob(value)
149 if self._validate_args:
150 self._validate_sample(value)
File ~/miniforge3/envs/py3.9/lib/python3.9/site-packages/torch/distributions/multivariate_normal.py:211, in MultivariateNormal.log_prob(self, value)
209 if self._validate_args:
210 self._validate_sample(value)
--> 211 diff = value - self.loc
212 M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
213 half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
RuntimeError: The size of tensor a (17) must match the size of tensor b (15) at non-singleton dimension 0
Expected Behavior
I think the second training procedure is supposed to work, right? It would seem sensible for Normalize to be updated with the new training information passed during conditioning.
System information
BoTorch version: 0.7.2
GPyTorch version: 1.9.0
Torch version: 1.12.0
Computer OS: macOS 12.5.1 (Mac M1 Max)
Hmm, interesting. It might be that some of the input transformations don't play well with the condition_on_observations call? @saitcakmak you've probably got the best understanding of the input transforms; could you please take a look?
I assume this works fine if you don't use an input transform?
@Balandat that is correct. In fact, it actually works with the Standardize outcome transform.
The issue seems to be that model._original_train_inputs is a 15 x 1-dim tensor (even after conditioning), so when we call mll.train(), model.train_inputs gets reverted back to model._original_train_inputs, losing the X's we just conditioned on.
This will be fixed in #1372, as it gets rid of _original_train_inputs and all the related hacks. In the meantime, you can achieve the warm-starting behavior by creating a new model and loading the hyperparameters using new_model.load_state_dict(old_model.state_dict()).
@saitcakmak thank you, I'll give your hack a try for now!
@saitcakmak so just to be clear about what you mean here: the idea would be to create an entirely new model with all of the same required objects (e.g. the transforms) but with the new training data (in my case a 17 x 1 tensor), and then to load state from the old model, setting the hyperparameters properly. Am I interpreting all of this correctly?
Thanks!
Edit: never mind, what I just asked about does appear to work regardless!
Yep, just repeat

model = botorch.models.SingleTaskGP(
    train_X=train_x,
    train_Y=train_y,
    input_transform=Normalize(1, transform_on_eval=True),
)

with the updated training data, and use the bit from the above comment to transfer over the hyperparameters.
@saitcakmak the problem, though, is that when initializing using what you show above, the correct parameters are set for the transforms but not the lengthscales. When you then do load_state_dict, it sets the old parameters for the input and output transforms (i.e., it sets the right lengthscales but the wrong transform parameters). This also seems to mess something up with the posterior. Even manually setting model.input_transform._buffers to the old ordered dict from the previous model doesn't work properly. The only way I can seem to fix this is by retraining.
Not entirely sure how to fix this on my end; it seems pretty complicated, since the posterior method somehow also involves the transforms. I'm not following it so well.
TL;DR: can I set the model lengthscales without overwriting the transform parameters and without re-calling train?
Also, an aside: does calling fit_gpytorch_mll on an already-trained model use the existing lengthscales and whatnot as initial guesses, thus perhaps speeding up the training?
If you do

new_model = botorch.models.SingleTaskGP(
    train_X=train_x,
    train_Y=train_y,
    input_transform=Normalize(1, transform_on_eval=True),
)
new_model.load_state_dict(old_model.state_dict())
it will update all buffers & parameters of the GP and its submodules with the corresponding values from old_model. If you then train this model with fit_gpytorch_mll, it should use these as the starting values for the training. Since the model will be called in train mode during training, the input transform buffers (e.g. for Normalize) should also get updated. As you noticed, this doesn't work with the outcome transforms, since they're only called in train mode while initializing the model. One way to get around this is to exclude the outcome transforms from the state dict loading:
new_model.load_state_dict(
    {k: v for k, v in old_model.state_dict().items() if "outcome_transform" not in k},
    strict=False,  # needed since the state dict is now missing certain keys
)
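Putting the pieces together, the workaround might look like this end to end (a sketch reusing the variable names from the reproduction above; old_model stands for the previously fitted model, and the Standardize outcome transform is assumed here for illustration):

from botorch.models.transforms.outcome import Standardize

# Rebuild the model on the enlarged dataset, then warm-start it from the
# old model's hyperparameters while keeping the freshly computed
# outcome-transform statistics (excluded from the partial state dict).
all_x = torch.cat([train_x, new_x])
all_y = torch.cat([train_y, new_y])
new_model = botorch.models.SingleTaskGP(
    train_X=all_x,
    train_Y=all_y,
    input_transform=Normalize(1, transform_on_eval=True),
    outcome_transform=Standardize(1),
)
new_model.load_state_dict(
    {k: v for k, v in old_model.state_dict().items() if "outcome_transform" not in k},
    strict=False,
)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(new_model.likelihood, new_model)
botorch.fit.fit_gpytorch_mll(mll)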
"can I set the model lengthscales without overwriting the transform parameters and without re-calling train?"

The above bit should work for this. Though, if you don't want to retrain / call train on the model, the original bug should not be an issue, so you can also just use condition_on_observations. That bug will only happen if you call model.train() after condition_on_observations.
"Also, an aside: does calling fit_gpytorch_mll on an already-trained model use the existing lengthscales and whatnot as initial guesses, thus perhaps speeding up the training?"
Yes, for the first model-fitting attempt, it should use the current model parameters as the initial values. If the first attempt fails and it has to retry, it will randomly sample from the priors. Intuitively, this should speed things up.
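If you want to sanity-check the warm start, something like this should do it (a hypothetical check; the kernel attribute path assumes the default SingleTaskGP covariance module, ScaleKernel(MaternKernel), in BoTorch 0.7.x):

# Record the lengthscale before refitting; the first optimization attempt
# is initialized at this value rather than at a random draw from the priors.
before = new_model.covar_module.base_kernel.lengthscale.detach().clone()
botorch.fit.fit_gpytorch_mll(
    gpytorch.mlls.ExactMarginalLogLikelihood(new_model.likelihood, new_model)
)
after = new_model.covar_module.base_kernel.lengthscale.detach()
print(before, after)  # `after` is the optimum reached starting from `before`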
@saitcakmak Your suggestion does work, but I'd like to bring another set of weird behaviors to your attention:
print(train_x)
"""
tensor([[0.0000],
        [0.7143],
        [1.4286],
        [2.1429],
        [2.8571]])
"""
print(train_y)
"""
tensor([-1.6720, 45.1938, 72.6386, 93.8865, 79.5389])
"""
model = botorch.models.SingleTaskGP(
    train_X=train_x,
    train_Y=train_y,
    input_transform=Normalize(1),
    outcome_transform=Standardize(1),
)
model.train_inputs[0]
"""
tensor([[0.0000],
        [0.7143],
        [1.4286],
        [2.1429],
        [2.8571]])
"""
model.train_targets
"""
tensor([-1.5798, -0.3373, 0.3903, 0.9536, 0.5732])
"""
For some reason, the inputs are not being transformed immediately like the targets are. This makes retrieving the current model's training data from the model itself really challenging, especially after reconditioning. Any thoughts?
EDIT: I think I've nailed down the source of this. It looks like the behavior of the training inputs changes with whether the model is in train() or eval() mode, whereas the targets are unaffected! Is this intended?
"I think I've nailed down the source of this. It looks like the behavior of the training inputs changes with whether the model is in train() or eval() mode, whereas the targets are unaffected! Is this intended?"
Looks like you figured it out. We transform the outcomes before passing them to GPyTorch, so train_targets will always show the transformed outcomes. For train_inputs, the story is a bit more complicated: currently, they will be untransformed if the model is in train mode and transformed if the model is in eval mode (with the originals available at _original_train_inputs). This was a hack to make sure we can apply some of the input transforms a bit more selectively to train / test inputs. I'll clean all of that up in #1372 (train_inputs will always be untransformed, and input transforms will be handled in GPyTorch), though I have some other things I need to work on before I get to that.
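To illustrate the mode-dependent behavior described above (a sketch against BoTorch 0.7.x, reusing the Normalize + Standardize model from the previous snippet):

# In train mode, train_inputs holds the raw, untransformed X values.
model.train()
print(model.train_inputs[0][:2])         # e.g. tensor([[0.0000], [0.7143]])

# In eval mode, train_inputs holds the Normalize-transformed values,
# with the raw copy stashed in _original_train_inputs.
model.eval()
print(model.train_inputs[0][:2])         # values scaled into [0, 1]
print(model._original_train_inputs[:2])  # the raw X values again

# train_targets always holds the Standardize-transformed outcomes,
# regardless of mode.
print(model.train_targets[:2])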
Ahh, so that's what _original_train_inputs is for. I never noticed that it doesn't change during transforms. OK, all of this sounds good. Looking forward to the BoTorch + GPyTorch update!
@saitcakmak funny that I come back to this exactly a year later. Have there been any updates on merging #1372? 😁 It looks like the GPyTorch side of the changes was ready to merge but never was. Thanks!
Hi @matthewcarbone, refactoring input transforms is still in progress. #1372 can't be merged as-is, since there are a good number of other changes that need to be made to ensure that such a large change works smoothly.
@esantorella ok no problem. I'm happy to help contribute somehow if there's any opportunity. Thanks! 👍