
[Bug] fit_gpytorch_mll gives backward pass runtime exception on second model fit attempt with HeteroskedasticSingleTaskGP

Open Tim-Infl opened this issue 8 months ago • 11 comments

🐛 Bug

fit_gpytorch_mll raises a RuntimeError about traversing the autograd graph backward a second time when fitting a HeteroskedasticSingleTaskGP on specific data (we are not sure whether this bug can also occur with other GPs).

To reproduce

**Code snippet to reproduce**

import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models.gp_regression import HeteroskedasticSingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize
from gpytorch.mlls import ExactMarginalLogLikelihood
x = [[0.052, 0.4, 124500000.0], [0.05761923313140868, 0.34092908203601835, 123245033.62178802], [0.02496520895510912, 0.564002669788897, 141507782.5821936], [0.03236924238502979, 0.10900852866470814, 116506682.57847428], [0.045027089267969125, 0.38581262519583104, 134769803.33216488], [0.04303258828818798, 0.20364383878186346, 138426915.16503692], [0.03944164723157882, 0.4200226565822959, 126660296.03220522], [0.026954370103776455, 0.2461717002093792, 131571062.21653521], [0.050551558472216124, 0.5299147198908031, 119804986.17514968], [0.05266778387129306, 0.15474029993638397, 132583202.42166519], [0.02970170006155968, 0.37108883671462534, 120905024.66075122], [0.037339153699576855, 0.2948328586295247, 140845334.48331058], [0.040311199314892285, 0.5786066324450075, 129166555.40466309], [0.047925519607961174, 0.2607450204901397, 117779762.43756711], [0.03464633349329233, 0.4838493674993515, 136072158.99974108], [0.022082343921065332, 0.18895020466297865, 125924335.01034975], [0.055365028381347645, 0.4657240201719105, 144216416.46884382], [0.056770838238298885, 0.16977797718718649, 122108987.12277412], [0.02067440327256918, 0.454394800029695, 132901076.1808604], [0.03324206084012985, 0.28065053597092626, 128846995.0389117], [0.04933253459632396, 0.49591188682243226, 139639398.3066082], [0.05240863501499841, 0.4056738189544114, 125318067.51154271], [0.05159197088597474, 0.3897060799790354, 123950711.44215755], [0.05151763477064899, 0.4390626950940063, 123950535.56188977], [0.0508759687423927, 0.37412064627267866, 124095160.38373113], [0.05210804609212111, 0.3751506052650311, 124433898.90715164], [0.052269320861764446, 0.34567018195966137, 123408419.40603946], [0.0508261035607972, 0.3811022974936308, 124472247.23175177], [0.05284520682346304, 0.3867088942948218, 124165296.76877046], [0.05208352835147538, 0.3457817061945064, 124553860.92886323], [0.05163773346255494, 0.3772675285570827, 124215683.8253139], [0.051188540407382646, 0.39203101090933506, 124129712.0694072], 
[0.04917095282864485, 0.38048079671593066, 124552463.35826813], [0.05170216626141591, 0.3890676227235632, 124395165.88114208], [0.05160025600830183, 0.38668564159236996, 124099926.55739184], [0.05370358849719392, 0.3556720936729202, 124680923.37365812], [0.05302638420478325, 0.36499738861205294, 124648968.28867507], [0.05243620330969558, 0.3639449683322785, 124269220.82900713], [0.053101485209401424, 0.3328020578813413, 124551760.53568736], [0.05460006467330157, 0.3166178243656474, 124941270.34209757], [0.054639360107865784, 0.27058576696481573, 124860886.689276]]
y = [[0.07915318230852243], [0.042475728155339794], [0.011192017259978421], [0.011596548004314991], [0.010922330097087376], [0.010382955771305283], [0.010517799352750806], [0.009978425026968713], [0.012135922330097084], [0.010113268608414236], [0.010248112189859758], [0.010113268608414236], [0.010787486515641851], [0.010517799352750806], [0.010248112189859758], [0.010113268608414236], [0.010113268608414236], [0.02467637540453074], [0.010787486515641851], [0.010922330097087376], [0.010113268608414236], [0.06499460625674247], [0.07996224379719555], [0.06863538295577158], [0.07861380798274033], [0.08279395900755154], [0.06270226537216857], [0.07025350593311787], [0.07173678532901863], [0.07483818770226566], [0.0802319309600866], [0.08778317152103592], [0.05272384034519955], [0.07133225458468205], [0.0680960086299895], [0.08063646170442317], [0.07874865156418587], [0.08414239482200678], [0.07996224379719555], [0.0860302049622441], [0.08265911542610602]]
y_std = [[0.0028213855281508378], [0.0022245261582752697], [0.0012120064968019716], [0.001233454542438619], [0.0011995366641219226], [0.0011701726643290147], [0.001178064737128823], [0.001148303988368498], [0.0012620390635098254], [0.001156042893708349], [0.0011626782061339128], [0.001156042893708349], [0.0011922268052885988], [0.001178064737128823], [0.001163429824420529], [0.0011558916484805534], [0.0011558916484805534], [0.0017600335962596214], [0.0011917867879991005], [0.001198953515490414], [0.0011558916484805534], [0.002627121625318947], [0.0028339987275969756], [0.0026693737536671105], [0.0028230581579652214], [0.002857529009561201], [0.0025919474867974277], [0.002684721301623331], [0.002707159759673002], [0.002765684513694602], [0.0028362804063563787], [0.0029301491870543017], [0.0024054051723306257], [0.0027262090910404155], [0.0026588079127803244], [0.002846740291066787], [0.002786467588404565], [0.002860769914027928], [0.0028289354755048346], [0.002875460788470168], [0.002869495979699964]]
x_t = torch.tensor(x) # Note we get the same error when dtype=torch.float64
y_t = torch.tensor(y)
y_std_t = torch.tensor(y_std)
model = HeteroskedasticSingleTaskGP(
    train_X=x_t,
    train_Y=y_t,
    train_Yvar=torch.square(y_std_t),
    input_transform=Normalize(x_t.shape[1]),
    outcome_transform=Standardize(1),
)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll)

**Stack trace/error message**

/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/models/gp_regression.py:298: UserWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444
  self._validate_tensor_args(X=train_X, Y=train_Y, Yvar=train_Yvar)
/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/models/utils/assorted.py:174: InputDataWarning: Input data is not contained to the unit cube. Please consider min-max scaling the input data.
  warnings.warn(msg, InputDataWarning)
/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/models/gp_regression.py:161: UserWarning: The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/pytorch/botorch/discussions/1444
  self._validate_tensor_args(X=transformed_X, Y=train_Y, Yvar=train_Yvar)
/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/models/utils/assorted.py:202: InputDataWarning: Input data is not standardized (mean = tensor([-5.7230]), std = tensor([0.8405])). Please consider scaling the input to zero mean and unit variance.
  warnings.warn(msg, InputDataWarning)
/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/fit.py:102: OptimizationWarning: `scipy_minimize` terminated with status 3, displaying original message from `scipy.optimize.minimize`: ABNORMAL_TERMINATION_IN_LNSRCH
  warn(
Traceback (most recent call last):
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/IPython/core/interactiveshell.py", line 3550, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-16-b7091d5acb92>", line 21, in <module>
    fit_gpytorch_mll(mll)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/fit.py", line 105, in fit_gpytorch_mll
    return FitGPyTorchMLL(
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/utils/dispatcher.py", line 93, in __call__
    return func(*args, **kwargs)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/fit.py", line 252, in _fit_fallback
    optimizer(mll, closure=closure, **optimizer_kwargs)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/fit.py", line 92, in fit_gpytorch_mll_scipy
    result = scipy_minimize(
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/core.py", line 109, in scipy_minimize
    raw = minimize_with_timeout(
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/utils/timeout.py", line 80, in minimize_with_timeout
    return optimize.minimize(
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_minimize.py", line 699, in minimize
    res = _minimize_lbfgsb(fun, x0, args, jac, bounds,
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_lbfgsb_py.py", line 362, in _minimize_lbfgsb
    f, g = func_and_grad(x)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py", line 285, in fun_and_grad
    self._update_fun()
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py", line 251, in _update_fun
    self._update_fun_impl()
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py", line 155, in update_fun
    self.f = fun_wrapped(self.x)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_differentiable_functions.py", line 137, in fun_wrapped
    fx = fun(np.copy(x), *args)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_optimize.py", line 76, in __call__
    self._compute_if_needed(x, *args)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/scipy/optimize/_optimize.py", line 70, in _compute_if_needed
    fg = self.fun(x, *args)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/closures/core.py", line 160, in __call__
    value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/utils/common.py", line 52, in _handle_numerical_errors
    raise error  # pragma: nocover
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/closures/core.py", line 150, in __call__
    value_tensor, grad_tensors = self.closure(**kwargs)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/botorch/optim/closures/core.py", line 66, in __call__
    self.backward(value)
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "/home/tim_lewis/.virtualenvs/py39_client/lib/python3.9/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Expected Behavior

We expect fit_gpytorch_mll to return successfully.

System information

Please complete the following information:

  • BoTorch version: 0.10.0
  • GPyTorch version: 1.11
  • PyTorch version: 2.2.2+cu121
  • OS: Ubuntu 22.04.3 LTS

Additional context

We get this exception from time to time while training our models, seemingly at random. It appears to be data-related: notably, removing or modifying the last data point in the given example makes the error go away. If an exception is expected here because something is wrong with the data, this still seems like the wrong error message to raise.
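
Since the failure seems tied to individual data points, a leave-one-out check may help narrow down which rows trigger it. The following is only a minimal sketch: `leave_one_out_failures` and `fit_subset` are hypothetical names that are not part of BoTorch, and `fit_subset` is assumed to be a user-supplied callable that rebuilds the model on the given row indices, runs the fit, and raises `RuntimeError` when the fit fails.

```python
from typing import Callable, List, Sequence


def leave_one_out_failures(
    n_points: int,
    fit_subset: Callable[[Sequence[int]], None],
) -> List[int]:
    """Return the indices whose removal lets the fit succeed.

    `fit_subset` is a hypothetical, user-supplied callable: it receives
    the row indices to keep, builds a fresh model on those rows, runs the
    fit, and raises RuntimeError if the fit fails.
    """
    culprits = []
    for i in range(n_points):
        # Keep every row except row i.
        keep = [j for j in range(n_points) if j != i]
        try:
            fit_subset(keep)
        except RuntimeError:
            # The fit still fails without row i, so row i alone is not
            # what triggers the error.
            continue
        culprits.append(i)
    return culprits
```

For the reproduction above, `fit_subset` could index `x_t`, `y_t`, and `y_std_t` with the kept rows, construct a new HeteroskedasticSingleTaskGP, and call fit_gpytorch_mll on it. Based on what we observed, the last index would be expected to show up in the result, but we have not run this check exhaustively.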

Tim-Infl · Jun 10 '24 19:06