
Implementing `qGIBBON` (`qLowerBoundMaxValueEntropy`) with Ax

NeptuneProjects opened this issue 1 year ago · 4 comments

Any chance an upcoming update will ship with the GIBBON acquisition function input constructor? 🥲 I'm trying to implement the qLowerBoundMaxValueEntropy acquisition function by following along with the tutorial, since it's not registered for use in Ax. I've looked at the examples in botorch.acquisition.input_constructors and botorch.acquisition.max_value_entropy_search, but after tinkering with several permutations of arguments and unsuccessfully tracing the inheritance through the MaxValueBase and DiscreteMaxValueBase base classes, I'm having trouble identifying the correct pattern for mapping the arguments.

An example of my futile attempt:

from typing import Any, Optional

import torch

from ax.modelbridge.registry import Models
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
from botorch.acquisition.input_constructors import (
    MaybeDict,
    acqf_input_constructor,
    _construct_inputs_mc_base,
    _get_dataset_field,
)
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.model import Model
from botorch.utils.datasets import SupervisedDataset
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood


class qGIBBON(qLowerBoundMaxValueEntropy):
    ...


@acqf_input_constructor(qGIBBON)
def construct_inputs_qGIBBON(
    model: Model,
    training_data: MaybeDict[SupervisedDataset],
    bounds: list[tuple[float, float]],
    objective: Optional[MCAcquisitionObjective] = None,
    posterior_transform: Optional[PosteriorTransform] = None,
    candidate_size: int = 1000,
    **kwargs: Any,
) -> dict[str, Any]:
    r"""Construct kwargs for `qMaxValueEntropy` constructor."""
    inputs_mc = _construct_inputs_mc_base(
        model=model,
        objective=objective,
    )

    X = _get_dataset_field(training_data, "X", first_only=True)
    _kw = {"device": X.device, "dtype": X.dtype}
    _rvs = torch.rand(candidate_size, len(bounds), **_kw)
    _bounds = torch.tensor(bounds, **_kw).transpose(0, 1)
    return {
        **inputs_mc,
        "candidate_set": _bounds[0] + (_bounds[1] - _bounds[0]) * _rvs,
        "maximize": kwargs.get("maximize", True),
    }


@optimizer_argparse.register(qGIBBON)
def _argparse_my_acqf(
    acqf: qGIBBON, sequential: bool = True
) -> dict:
    return {
        "use_gumbel": True
    }

I'm using the acquisition function with the generation strategy/service client API:

gs = GenerationStrategy(
        steps=[
            GenerationStep(
                model=Models.SOBOL,
                num_trials=64,
                max_parallelism=64,
            ),
            GenerationStep(
                model=Models.BOTORCH_MODULAR,
                num_trials=500 - 64,
                max_parallelism=None,
                model_kwargs={
                    "surrogate": Surrogate(
                        botorch_model_class=SingleTaskGP,
                        mll_class=ExactMarginalLogLikelihood,
                    ),
                    "botorch_acqf_class": qGIBBON,
                    "torch_device": device,
                },
                model_gen_kwargs={
                    "model_gen_options": {
                        "optimizer_kwargs": {
                            "num_restarts": 120,
                            "raw_samples": 4096,
                        }
                    }
                },
            ),
        ]
    )

which results in the following error, unsurprisingly since construct_inputs_qGIBBON is just copied & pasted from construct_inputs_qMES:

Traceback (most recent call last):
  File "/lib/python3.10/site-packages/ax/service/ax_client.py", line 631, in get_next_trials
    params, trial_index = self.get_next_trial(ttl_seconds=ttl_seconds)
  File "/lib/python3.10/site-packages/ax/utils/common/executils.py", line 161, in actual_wrapper
    return func(*args, **kwargs)
  File "/lib/python3.10/site-packages/ax/service/ax_client.py", line 545, in get_next_trial
    generator_run=self._gen_new_generator_run(), ttl_seconds=ttl_seconds
  File "/lib/python3.10/site-packages/ax/service/ax_client.py", line 1734, in _gen_new_generator_run
    return not_none(self.generation_strategy).gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/generation_strategy.py", line 334, in gen
    return self._gen_multiple(
  File "/lib/python3.10/site-packages/ax/modelbridge/generation_strategy.py", line 479, in _gen_multiple
    generator_run = self._curr.gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/generation_node.py", line 408, in gen
    gr = super().gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/generation_node.py", line 193, in gen
    generator_run = model_spec.gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/model_spec.py", line 225, in gen
    return fitted_model.gen(**model_gen_kwargs)
  File "/lib/python3.10/site-packages/ax/modelbridge/base.py", line 805, in gen
    gen_results = self._gen(
  File "/lib/python3.10/site-packages/ax/modelbridge/torch.py", line 611, in _gen
    gen_results = not_none(self.model).gen(
  File "/lib/python3.10/site-packages/ax/models/torch/botorch_modular/model.py", line 451, in gen
    acqf = self._instantiate_acquisition(
  File "/lib/python3.10/site-packages/ax/models/torch/botorch_modular/model.py", line 664, in _instantiate_acquisition
    return self.acquisition_class(
  File "/lib/python3.10/site-packages/ax/models/torch/botorch_modular/acquisition.py", line 312, in __init__
    self.acqf = botorch_acqf_class(**acqf_inputs)  # pyre-ignore [45]
TypeError: DiscreteMaxValueBase.__init__() got an unexpected keyword argument 'objective'

NeptuneProjects avatar Oct 10 '23 00:10 NeptuneProjects

Hi @NeptuneProjects. Thanks for the question. It looks like _construct_inputs_mc_base() is from an earlier version of BoTorch, so I'm not exactly sure what it does, but I repro'd without it. qLowerBoundMaxValueEntropy inherits from DiscreteMaxValueBase, which doesn't take an objective argument in its init, so you need to make sure "objective" is not a key in the dict returned by construct_inputs_qGIBBON(). The next problem I run into is that the train_inputs arg is required, which may be because I'm not using _construct_inputs_mc_base().
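For reference, here's a rough, untested sketch of what that constructor could return with "objective" dropped (the candidate_size/maximize handling is just carried over from your snippet above, so treat the details as assumptions):

import torch

from botorch.acquisition.input_constructors import acqf_input_constructor, _get_dataset_field
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy


class qGIBBON(qLowerBoundMaxValueEntropy):
    ...


@acqf_input_constructor(qGIBBON)
def construct_inputs_qGIBBON(model, training_data, bounds, candidate_size=1000, maximize=True, **kwargs):
    # Draw a uniform random candidate set inside the search-space bounds.
    X = _get_dataset_field(training_data, "X", first_only=True)
    tkwargs = {"device": X.device, "dtype": X.dtype}
    rvs = torch.rand(candidate_size, len(bounds), **tkwargs)
    lower, upper = torch.tensor(bounds, **tkwargs).transpose(0, 1)
    # Note: no "objective" key here, since DiscreteMaxValueBase.__init__ does not accept it.
    return {
        "model": model,
        "candidate_set": lower + (upper - lower) * rvs,
        "maximize": maximize,
    }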

cc @Balandat

danielcohenlive avatar Oct 10 '23 15:10 danielcohenlive

To follow up, @NeptuneProjects: can you confirm what is in the inputs_mc dict, and also try calling del inputs_mc["objective"] to see if that solves the problem? It would also be useful to know which versions of BoTorch and Ax you're using.

danielcohenlive avatar Oct 13 '23 14:10 danielcohenlive

+1 to understanding what versions of Ax/BoTorch you're currently using. On trunk, the input constructor for qMaxValueEntropy (which also subclasses DiscreteMaxValueBase) should be a good starting point for qLowerBoundMaxValueEntropy (no need to define a qGibbon class): https://github.com/pytorch/botorch/blob/841392063a5289377cb3ee46c62827003ff2bc07/botorch/acquisition/input_constructors.py#L1022-L1041 - in fact, unless you want to expose some of the optional args, you should be able to just use that verbatim.
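Roughly, that reuse could look like this (untested, and assuming your BoTorch version doesn't already register an input constructor for qLowerBoundMaxValueEntropy):

from botorch.acquisition.input_constructors import acqf_input_constructor, construct_inputs_qMES
from botorch.acquisition.max_value_entropy_search import qLowerBoundMaxValueEntropy

# Reuse the existing qMaxValueEntropy input constructor for GIBBON; both classes
# subclass DiscreteMaxValueBase and accept the same constructor inputs.
acqf_input_constructor(qLowerBoundMaxValueEntropy)(construct_inputs_qMES)

Then you can pass "botorch_acqf_class": qLowerBoundMaxValueEntropy to the BOTORCH_MODULAR GenerationStep directly, without the qGIBBON subclass.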

Balandat avatar Oct 15 '23 01:10 Balandat

I have been using Ax 0.3.1 and BoTorch 0.8.3. I should have upgraded before I asked the question - I will do that now, play with your suggestions, and get back to you.

NeptuneProjects avatar Oct 18 '23 15:10 NeptuneProjects

@NeptuneProjects I'm going to assume this solved your problem and close the issue, but feel free to reopen if you have further issues.

danielcohenlive avatar Jul 08 '24 21:07 danielcohenlive