Implementing `qGIBBON` (`qLowerBoundMaxValueEntropy`) with Ax
Any chance an upcoming update will ship with the GIBBON acquisition function input constructor? 🥲 I'm trying to implement the `qLowerBoundMaxValueEntropy` acquisition function by following along with the tutorial, since it's not registered for use in Ax. I've looked at the examples in `botorch.acquisition.input_constructors` and `botorch.acquisition.max_value_entropy_search`, but after tinkering with several permutations of arguments and unsuccessfully tracing the inheritance through the `MaxValueBase` and `DiscreteMaxValueBase` base classes, I'm having trouble identifying the correct pattern for mapping the arguments.
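When tracing arguments through a class hierarchy, it can help to ask Python directly which class's `__init__` will receive the constructor's output and which keyword names it accepts. The sketch below does that with `inspect`. It uses hypothetical stand-in classes named after the hierarchy in this thread; their parameter lists are illustrative assumptions, not BoTorch's exact signatures. The helper names `init_owner` and `init_params` are also hypothetical.

```python
import inspect


# Hypothetical stand-ins mirroring the class hierarchy named in this thread.
# Parameter lists are illustrative assumptions, not BoTorch's exact signatures.
class MaxValueBase:
    def __init__(self, model, num_mv_samples=10, maximize=True):
        self.model = model
        self.num_mv_samples = num_mv_samples
        self.maximize = maximize


class DiscreteMaxValueBase(MaxValueBase):
    def __init__(self, model, candidate_set, num_mv_samples=10,
                 use_gumbel=True, maximize=True):
        super().__init__(model, num_mv_samples, maximize)
        self.candidate_set = candidate_set
        self.use_gumbel = use_gumbel


class qLowerBoundMaxValueEntropy(DiscreteMaxValueBase):
    pass  # defines no __init__ of its own


class qGIBBON(qLowerBoundMaxValueEntropy):
    ...  # same empty subclass as in the question


def init_owner(cls):
    """Return the first class in the MRO that defines __init__ itself."""
    for klass in cls.__mro__:  # `object` is always last and defines __init__
        if "__init__" in vars(klass):
            return klass


def init_params(cls):
    """Names that cls(...) accepts as keyword arguments (excluding self)."""
    kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
    params = inspect.signature(init_owner(cls).__init__).parameters
    return [n for n, p in params.items() if p.kind in kinds and n != "self"]


print(init_owner(qGIBBON).__name__)  # DiscreteMaxValueBase
print(init_params(qGIBBON))
# ['model', 'candidate_set', 'num_mv_samples', 'use_gumbel', 'maximize']
```

With BoTorch installed, the same two helpers could be pointed at the real classes; the traceback further down suggests the answer would again be `DiscreteMaxValueBase.__init__`.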
An example of my futile attempt:
```python
from typing import Any, Optional

import torch
from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
from botorch.acquisition.input_constructors import (
    MaybeDict,
    acqf_input_constructor,
    _construct_inputs_mc_base,
    _get_dataset_field,
)
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood


class qGIBBON(qLowerBoundMaxValueEntropy):
    ...


@acqf_input_constructor(qGIBBON)
def construct_inputs_qGIBBON(
    model: Model,
    training_data: MaybeDict[SupervisedDataset],
    bounds: list[tuple[float, float]],
    objective: Optional[MCAcquisitionObjective] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    candidate_size: int = 1000,
    **kwargs: Any,
) -> dict[str, Any]:
    r"""Construct kwargs for `qMaxValueEntropy` constructor."""
    inputs_mc = _construct_inputs_mc_base(
        model=model,
        objective=objective,
    )
    # Sample `candidate_size` points uniformly within the search-space bounds.
    X = _get_dataset_field(training_data, "X", first_only=True)
    _kw = {"device": X.device, "dtype": X.dtype}
    _rvs = torch.rand(candidate_size, len(bounds), **_kw)
    _bounds = torch.tensor(bounds, **_kw).transpose(0, 1)
    return {
        **inputs_mc,
        "candidate_set": _bounds[0] + (_bounds[1] - _bounds[0]) * _rvs,
        "maximize": kwargs.get("maximize", True),
    }


@optimizer_argparse.register(qGIBBON)
def _argparse_my_acqf(
    acqf: qGIBBON, sequential: bool = True
) -> dict:
    return {
        "use_gumbel": True
    }
```
I'm using the acquisition function with the generation strategy/service client API:
```python
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy

# `device` is a torch.device defined elsewhere in the script.
gs = GenerationStrategy(
    steps=[
        GenerationStep(
            model=Models.SOBOL,
            num_trials=64,
            max_parallelism=64,
        ),
        GenerationStep(
            model=Models.BOTORCH_MODULAR,
            num_trials=500 - 64,
            max_parallelism=None,
            model_kwargs={
                "surrogate": Surrogate(
                    botorch_model_class=SingleTaskGP,
                    mll_class=ExactMarginalLogLikelihood,
                ),
                "botorch_acqf_class": qGIBBON,
                "torch_device": device,
            },
            model_gen_kwargs={
                "model_gen_options": {
                    "optimizer_kwargs": {
                        "num_restarts": 120,
                        "raw_samples": 4096,
                    }
                }
            },
        ),
    ]
)
```
This results in the following error, unsurprisingly, since `construct_inputs_qGIBBON` is just copied and pasted from `construct_inputs_qMES`:
```
Traceback (most recent call last):
  File "/lib/python3.10/site-packages/ax/service/ax_client.py", line 631, in get_next_trials
    params, trial_index = self.get_next_trial(ttl_seconds=ttl_seconds)
  File "/lib/python3.10/site-packages/ax/utils/common/executils.py", line 161, in actual_wrapper
    return func(*args, **kwargs)
  File "/lib/python3.10/site-packages/ax/service/ax_client.py", line 545, in get_next_trial
    generator_run=self._gen_new_generator_run(), ttl_seconds=ttl_seconds
  File "/lib/python3.10/site-packages/ax/service/ax_client.py", line 1734, in _gen_new_generator_run
    return not_none(self.generation_strategy).gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/generation_strategy.py", line 334, in gen
    return self._gen_multiple(
  File "/lib/python3.10/site-packages/ax/modelbridge/generation_strategy.py", line 479, in _gen_multiple
    generator_run = self._curr.gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/generation_node.py", line 408, in gen
    gr = super().gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/generation_node.py", line 193, in gen
    generator_run = model_spec.gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/model_spec.py", line 225, in gen
    return fitted_model.gen(**model_gen_kwargs)
  File "/lib/python3.10/site-packages/ax/modelbridge/base.py", line 805, in gen
    gen_results = self._gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/torch.py", line 611, in _gen
    gen_results = not_none(self.model).gen(
  File "/lib/python3.10/site-packages/ax/models/torch/botorch_modular/model.py", line 451, in gen
    acqf = self._instantiate_acquisition(
  File "/lib/python3.10/site-packages/ax/models/torch/botorch_modular/model.py", line 664, in _instantiate_acquisition
    return self.acquisition_class(
  File "/lib/python3.10/site-packages/ax/models/torch/botorch_modular/acquisition.py", line 312, in __init__
    self.acqf = botorch_acqf_class(**acqf_inputs)  # pyre-ignore [45]
TypeError: DiscreteMaxValueBase.__init__() got an unexpected keyword argument 'objective'
```
Hi @NeptuneProjects. Thanks for the question. It looks like `_construct_inputs_mc_base()` is from an earlier version of BoTorch, so I'm not exactly sure what it does, but I reproduced the issue without it. `qLowerBoundMaxValueEntropy` inherits from `DiscreteMaxValueBase`, which doesn't take an `objective` argument in `__init__`, so you need to make sure `"objective"` is not a key in the dict returned by `construct_inputs_qGIBBON()`. The next problem I run into, though, is that the `train_inputs` arg is required, which may be because I'm not using `_construct_inputs_mc_base()`.
cc @Balandat
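The failure mode described above is plain Python behavior: passing a keyword that `__init__` doesn't declare (and that no `**kwargs` catch-all absorbs) raises `TypeError`. A minimal sketch of one defensive option is to drop keys the target constructor doesn't accept before instantiating. Here `DiscreteMaxValueBase` is a hypothetical stand-in with an assumed parameter list, `filter_kwargs` is a hypothetical helper, and the dict values are placeholders rather than real BoTorch objects.

```python
import inspect


# Hypothetical stand-in: like the real class (per this thread), its __init__
# has no `objective` parameter and no **kwargs catch-all.
class DiscreteMaxValueBase:
    def __init__(self, model, candidate_set, use_gumbel=True, maximize=True):
        self.model = model
        self.candidate_set = candidate_set
        self.use_gumbel = use_gumbel
        self.maximize = maximize


def filter_kwargs(cls, kwargs):
    """Keep only the keys that cls.__init__ accepts by name."""
    params = inspect.signature(cls.__init__).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(kwargs)  # a **kwargs catch-all accepts every key
    return {k: v for k, v in kwargs.items() if k in params and k != "self"}


# Shape of what a copied qMES-style constructor might return (placeholder values).
acqf_inputs = {"model": "gp", "objective": None, "candidate_set": [[0.5]], "maximize": True}

try:
    DiscreteMaxValueBase(**acqf_inputs)
except TypeError as err:
    print(err)  # ...got an unexpected keyword argument 'objective'

acqf = DiscreteMaxValueBase(**filter_kwargs(DiscreteMaxValueBase, acqf_inputs))
print(acqf.maximize)  # True
```

Silently dropping keys can also hide typos in argument names, so the targeted fix of removing `"objective"` from the returned dict is the more explicit choice; the filter is mainly useful while exploring which arguments a class accepts.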
To follow up, @NeptuneProjects, can you confirm what is in the `inputs_mc` dict, and also try calling `del inputs_mc["objective"]` to see if that solves the problem? It would also be useful to know which versions of BoTorch and Ax you're using.
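One way to collect the requested version numbers from inside the same environment is the standard library's `importlib.metadata`. This is a small sketch; the helper name `installed_version` is hypothetical, and it relies on Ax being published on PyPI as `ax-platform`.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Installed version string of a distribution, or None if it isn't installed."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Distribution names as published on PyPI (Ax is "ax-platform").
for dist in ("ax-platform", "botorch", "gpytorch", "torch"):
    print(dist, installed_version(dist))
```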
+1 to understanding which versions of Ax/BoTorch you're currently using. On trunk, the input constructor for `qMaxValueEntropy` (which also subclasses `DiscreteMaxValueBase`) should be a good starting point for `qLowerBoundMaxValueEntropy` (no need to define a `qGibbon` class): https://github.com/pytorch/botorch/blob/841392063a5289377cb3ee46c62827003ff2bc07/botorch/acquisition/input_constructors.py#L1022-L1041. In fact, unless you want to expose some of the optional args, you should be able to just use that verbatim.
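The core of that constructor, as mirrored in the `construct_inputs_qGIBBON` code in the question, is building `candidate_set`: the discrete set of points over which `DiscreteMaxValueBase` draws its max-value samples. The torch expression `_bounds[0] + (_bounds[1] - _bounds[0]) * _rvs` draws uniform points inside the box. A pure-Python sketch of the same computation (the function name `uniform_candidate_set` and the `seed` argument are hypothetical additions for illustration):

```python
import random


def uniform_candidate_set(bounds, candidate_size, seed=None):
    """Draw `candidate_size` points uniformly inside per-dimension (lower, upper) bounds.

    Mirrors lower + (upper - lower) * U(0, 1) from the torch version.
    """
    rng = random.Random(seed)
    return [
        [lo + (hi - lo) * rng.random() for lo, hi in bounds]
        for _ in range(candidate_size)
    ]


bounds = [(0.0, 1.0), (-5.0, 5.0)]  # one (lower, upper) pair per parameter
cands = uniform_candidate_set(bounds, candidate_size=1000, seed=0)
print(len(cands), len(cands[0]))  # 1000 2
```

A larger `candidate_size` discretizes the space more finely at the cost of more posterior evaluations; the question's constructor uses 1000 as its default.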
I have been using Ax 0.3.1 and BoTorch 0.8.3. I should have upgraded before asking the question; I will do that now, try your suggestions, and get back to you.
@NeptuneProjects I'm going to assume this solved your problem and close the issue, but feel free to reopen if you have further issues.