[Bug] fit_gpytorch_mll gives backward pass runtime exception on second model fit attempt with HeteroskedasticSingleTaskGP
🐛 Bug
fit_gpytorch_mll raises a runtime error about trying to traverse the graph backward a second time on the specific data below, in this case using a HeteroskedasticSingleTaskGP (we are not sure whether this bug can also occur with other GPs).
To reproduce
**Code snippet to reproduce**
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood
x = [[0.052, 0.4, 124500000.0], [0.05761923313140868, 0.34092908203601835, 123245033.62178802], [0.02496520895510912, 0.564002669788897, 141507782.5821936], [0.03236924238502979, 0.10900852866470814, 116506682.57847428], [0.045027089267969125, 0.38581262519583104, 134769803.33216488], [0.04303258828818798, 0.20364383878186346, 138426915.16503692], [0.03944164723157882, 0.4200226565822959, 126660296.03220522], [0.026954370103776455, 0.2461717002093792, 131571062.21653521], [0.050551558472216124, 0.5299147198908031, 119804986.17514968], [0.05266778387129306, 0.15474029993638397, 132583202.42166519], [0.02970170006155968, 0.37108883671462534, 120905024.66075122], [0.037339153699576855, 0.2948328586295247, 140845334.48331058], [0.040311199314892285, 0.5786066324450075, 129166555.40466309], [0.047925519607961174, 0.2607450204901397, 117779762.43756711], [0.03464633349329233, 0.4838493674993515, 136072158.99974108], [0.022082343921065332, 0.18895020466297865, 125924335.01034975], [0.055365028381347645, 0.4657240201719105, 144216416.46884382], [0.056770838238298885, 0.16977797718718649, 122108987.12277412], [0.02067440327256918, 0.454394800029695, 132901076.1808604], [0.03324206084012985, 0.28065053597092626, 128846995.0389117], [0.04933253459632396, 0.49591188682243226, 139639398.3066082], [0.05240863501499841, 0.4056738189544114, 125318067.51154271], [0.05159197088597474, 0.3897060799790354, 123950711.44215755], [0.05151763477064899, 0.4390626950940063, 123950535.56188977], [0.0508759687423927, 0.37412064627267866, 124095160.38373113], [0.05210804609212111, 0.3751506052650311, 124433898.90715164], [0.052269320861764446, 0.34567018195966137, 123408419.40603946], [0.0508261035607972, 0.3811022974936308, 124472247.23175177], [0.05284520682346304, 0.3867088942948218, 124165296.76877046], [0.05208352835147538, 0.3457817061945064, 124553860.92886323], [0.05163773346255494, 0.3772675285570827, 124215683.8253139], [0.051188540407382646, 0.39203101090933506, 124129712.0694072], [0.04917095282864485, 0.38048079671593066, 124552463.35826813], [0.05170216626141591, 0.3890676227235632, 124395165.88114208], [0.05160025600830183, 0.38668564159236996, 124099926.55739184], [0.05370358849719392, 0.3556720936729202, 124680923.37365812], [0.05302638420478325, 0.36499738861205294, 124648968.28867507], [0.05243620330969558, 0.3639449683322785, 124269220.82900713], [0.053101485209401424, 0.3328020578813413, 124551760.53568736], [0.05460006467330157, 0.3166178243656474, 124941270.34209757], [0.054639360107865784, 0.27058576696481573, 124860886.689276]]
y = [[0.07915318230852243], [0.042475728155339794], [0.011192017259978421], [0.011596548004314991], [0.010922330097087376], [0.010382955771305283], [0.010517799352750806], [0.009978425026968713], [0.012135922330097084], [0.010113268608414236], [0.010248112189859758], [0.010113268608414236], [0.010787486515641851], [0.010517799352750806], [0.010248112189859758], [0.010113268608414236], [0.010113268608414236], [0.02467637540453074], [0.010787486515641851], [0.010922330097087376], [0.010113268608414236], [0.06499460625674247], [0.07996224379719555], [0.06863538295577158], [0.07861380798274033], [0.08279395900755154], [0.06270226537216857], [0.07025350593311787], [0.07173678532901863], [0.07483818770226566], [0.0802319309600866], [0.08778317152103592], [0.05272384034519955], [0.07133225458468205], [0.0680960086299895], [0.08063646170442317], [0.07874865156418587], [0.08414239482200678], [0.07996224379719555], [0.0860302049622441], [0.08265911542610602]]
y_std = [[0.0028213855281508378], [0.0022245261582752697], [0.0012120064968019716], [0.001233454542438619], [0.0011995366641219226], [0.0011701726643290147], [0.001178064737128823], [0.001148303988368498], [0.0012620390635098254], [0.001156042893708349], [0.0011626782061339128], [0.001156042893708349], [0.0011922268052885988], [0.001178064737128823], [0.001163429824420529], [0.0011558916484805534], [0.0011558916484805534], [0.0017600335962596214], [0.0011917867879991005], [0.001198953515490414], [0.0011558916484805534], [0.002627121625318947], [0.0028339987275969756], [0.0026693737536671105], [0.0028230581579652214], [0.002857529009561201], [0.0025919474867974277], [0.002684721301623331], [0.002707159759673002], [0.002765684513694602], [0.0028362804063563787], [0.0029301491870543017], [0.0024054051723306257], [0.0027262090910404155], [0.0026588079127803244], [0.002846740291066787], [0.002786467588404565], [0.002860769914027928], [0.0028289354755048346], [0.002875460788470168], [0.002869495979699964]]
x_t = torch.tensor(x)  # Note: we get the same error with dtype=torch.float64
y_t = torch.tensor(y)
y_std_t = torch.tensor(y_std)
model = HeteroskedasticSingleTaskGP(
    train_X=x_t,
    train_Y=y_t,
    train_Yvar=torch.square(y_std_t),
    input_transform=Normalize(x_t.shape[1]),
    outcome_transform=Standardize(1),
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)
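As an aside, one possible fallback (a sketch only, not a fix for the underlying bug, and we have not verified that it avoids the error on this data) is to retry the fit with the torch-based optimizer when the default scipy L-BFGS-B path raises:

from botorch.optim.fit import fit_gpytorch_mll_torch

# Illustrative fallback only: retry with the torch (Adam-based) optimizer
# if the default scipy fit raises; not verified to avoid the error here.
try:
    fit_gpytorch_mll(mll)
except RuntimeError:
    fit_gpytorch_mll(mll, optimizer=fit_gpytorch_mll_torch)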
**Stack trace/error message**
/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/models/gp_regression.py:298: UserWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444
self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/models/utils/assorted.py:174: InputDataWarning: Input data is not contained to the unit cube. Please consider min-max scaling the input data.
warnings.warn(msg, InputDataWarning)
/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/models/gp_regression.py:161: UserWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444
self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/models/utils/assorted.py:202: InputDataWarning: Input data is not standardized (mean = tensor([-5.7230]), std = tensor([0.8405])). Please consider scaling the input to zero mean and unit variance.
warnings.warn(msg, InputDataWarning)
/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/fit.py:102: OptimizationWarning: `scipy_minimize` terminated with status 3, displaying original message from `scipy.optimize.minimize`: ABNORMAL_TERMINATION_IN_LNSRCH
warn(
Traceback (most recent call last):
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-16-b7091d5acb92>", line 21, in <module>
fit_gpytorch_mll(mll)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/fit.py", line 105, in fit_gpytorch_mll
return FitGPyTorchMLL(
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/utils/dispatcher.py", line 93, in __call__
return func(*args, **kwargs)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/fit.py", line 252, in _fit_fallback
optimizer(mll, closure=closure, **optimizer_kwargs)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/fit.py", line 92, in fit_gpytorch_mll_scipy
result = scipy_minimize(
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/core.py", line 109, in scipy_minimize
raw = minimize_with_timeout(
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/utils/timeout.py", line 80, in minimize_with_timeout
return optimize.minimize(
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_minimize.py", line 699, in minimize
res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_lbfgsb_py.py", line 362, in _minimize_lbfgsb
f, g = func_and_grad(x)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py", line 285, in fun_and_grad
self._update_fun()
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py", line 251, in _update_fun
self._update_fun_impl()
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py", line 155, in update_fun
self.f = fun_wrapped(self.x)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py", line 137, in fun_wrapped
fx = fun(np.copy(x), *args)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_optimize.py", line 76, in __call__
self._compute_if_needed(x, *args)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_optimize.py", line 70, in _compute_if_needed
fg = self.fun(x, *args)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/closures/core.py", line 160, in __call__
value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/utils/common.py", line 52, in _handle_numerical_errors
raise error # pragma: nocover
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/closures/core.py", line 150, in __call__
value_tensor, grad_tensors = self.closure(**kwargs)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/closures/core.py", line 66, in __call__
self.backward(value)
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/torch/_tensor.py", line 522, in backward
torch.autograd.backward(
File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/torch/autograd/__init__.py", line 266, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
Expected Behavior
We expect fit_gpytorch_mll to return successfully.
System information
- Botorch version: 0.10.0
- GPyTorch version: 1.11
- PyTorch version: 2.2.2+cu121
- OS: Ubuntu 22.04.3 LTS
Additional context
We get this exception from time to time while training our models, seemingly at random; it appears to be data-related. Notably, removing or modifying the last data point in the given example makes the error go away. If an exception is expected here because something is wrong with the data, this still seems like the wrong error message to raise.
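For reference, here is the same snippet with the last observation dropped (the modification described above); with this change the fit completes without the RuntimeError in our testing:

# Same setup as above, but with the last data point removed;
# in our testing this avoids the backward-pass RuntimeError.
x_t = torch.tensor(x[:-1])
y_t = torch.tensor(y[:-1])
y_std_t = torch.tensor(y_std[:-1])
model = HeteroskedasticSingleTaskGP(
    train_X=x_t,
    train_Y=y_t,
    train_Yvar=torch.square(y_std_t),
    input_transform=Normalize(x_t.shape[1]),
    outcome_transform=Standardize(1),
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)  # succeeds in our testing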