Ax icon indicating copy to clipboard operation
Ax copied to clipboard

EHVI & NEHVI break with more than 7 objectives

Open ronald-jaepel opened this issue 2 months ago • 4 comments

Hello Ax Team,

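When running EHVI or NEHVI with more than 7 objectives, `get_next_trial` fails with an `IndexError`. The error is raised during candidate generation, while the acquisition function is evaluated to build the initial conditions for its optimization, not during evaluation of our objective function (see the traceback below).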

Here's a minimal reproducible example (MRE):

import numpy as np
from ax.service.ax_client import AxClient, ObjectiveProperties

# Number of objectives. The issue reports that (N)EHVI fails for more than 7
# objectives, so 8 is the smallest count that reproduces the error.
N_OBJECTIVES = 8

# Create a multi-objective experiment: a 2-D continuous search space and
# N_OBJECTIVES objectives named "Objective_0" ... "Objective_{N_OBJECTIVES-1}".
# Each objective is minimized and has an objective threshold of 1.
# NOTE(review): per the traceback, Ax's default generation strategy ends up
# optimizing a BoTorch qLogNEHVI acquisition function (logei.py) for this
# experiment; the exact model choice is made inside Ax, not here.
ax_client = AxClient()
ax_client.create_experiment(
    name="test_experiment",
    # Search space: two float range parameters.
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [-5.0, 10.0],
            "value_type": "float",
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 10.0],
            "value_type": "float",
        },
    ],
    # One minimized objective per index, each with threshold=1.
    objectives={
        f"Objective_{i}": ObjectiveProperties(minimize=True, threshold=1) for i in range(N_OBJECTIVES)
    },
)


def objective_function():
    """Return one random metric value per objective.

    Each value is drawn with ``np.random.rand()`` (uniform on [0, 1)). The
    values are keyed ``Objective_0`` ... ``Objective_{N_OBJECTIVES - 1}``,
    which are the objective names registered with the experiment. The
    function takes no trial parameters, so the outcome does not depend on
    the point Ax suggests.
    """
    # Generate the names lazily, so rand() is called once per name, in order.
    names = (f"Objective_{k}" for k in range(N_OBJECTIVES))
    return {name: np.random.rand() for name in names}


# Ask/tell loop: request 15 trials from Ax and report random outcomes for each.
# `parameters` is unused because objective_function ignores the suggested point.
# NOTE(review): presumably the first trials come from an initial (non-model)
# generation step. The traceback shows that the failure is raised inside
# get_next_trial once acquisition-function optimization runs.
for _ in range(15):
    parameters, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=trial_index, raw_data=objective_function())

And here is the traceback, starting from the `get_next_trial` call:

  File "...\ax\service\ax_client.py", line 531, in get_next_trial
    generator_run=self._gen_new_generator_run(), ttl_seconds=ttl_seconds
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\ax\service\ax_client.py", line 1763, in _gen_new_generator_run
    return not_none(self.generation_strategy).gen(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\ax\modelbridge\generation_strategy.py", line 478, in gen
    return self._gen_multiple(
           ^^^^^^^^^^^^^^^^^^^
  File "...\ax\modelbridge\generation_strategy.py", line 675, in _gen_multiple
    generator_run = self._curr.gen(
                    ^^^^^^^^^^^^^^^
  File "...\ax\modelbridge\generation_node.py", line 737, in gen
    gr = super().gen(
         ^^^^^^^^^^^^
  File "...\ax\modelbridge\generation_node.py", line 307, in gen
    generator_run = model_spec.gen(
                    ^^^^^^^^^^^^^^^
  File "...\ax\modelbridge\model_spec.py", line 219, in gen
    return fitted_model.gen(**model_gen_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\ax\modelbridge\base.py", line 784, in gen
    gen_results = self._gen(
                  ^^^^^^^^^^
  File "...\ax\modelbridge\torch.py", line 690, in _gen
    gen_results = not_none(self.model).gen(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\ax\models\torch\botorch_modular\model.py", line 428, in gen
    candidates, expected_acquisition_value = acqf.optimize(
                                             ^^^^^^^^^^^^^^
  File "...\ax\models\torch\botorch_modular\acquisition.py", line 439, in optimize
    return optimize_acqf(
           ^^^^^^^^^^^^^^
  File "...\botorch\optim\optimize.py", line 563, in optimize_acqf
    return _optimize_acqf(opt_acqf_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\botorch\optim\optimize.py", line 584, in _optimize_acqf
    return _optimize_acqf_batch(opt_inputs=opt_inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\botorch\optim\optimize.py", line 274, in _optimize_acqf_batch
    batch_initial_conditions = opt_inputs.get_ic_generator()(
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\botorch\optim\initializers.py", line 417, in gen_batch_initial_conditions
    Y_rnd_curr = acq_function(
                 ^^^^^^^^^^^^^
  File "...\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\torch\nn\modules\module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\botorch\utils\transforms.py", line 305, in decorated
    return method(cls, X, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\botorch\utils\transforms.py", line 259, in decorated
    output = method(acqf, X, *args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\botorch\acquisition\multi_objective\logei.py", line 468, in forward
    nehvi = self._compute_log_qehvi(samples=samples, X=X)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\botorch\acquisition\multi_objective\logei.py", line 267, in _compute_log_qehvi
    return logmeanexp(logsumexp(log_areas_per_segment, dim=-1), dim=0)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\botorch\utils\safe_math.py", line 146, in logsumexp
    return _inf_max_helper(torch.logsumexp, x=x, dim=dim, keepdim=keepdim)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "...\botorch\utils\safe_math.py", line 170, in _inf_max_helper
    M = x.amax(dim=dim, keepdim=True)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: amax(): Expected reduction dim -1 to have non-zero size.

ronald-jaepel avatar Apr 21 '24 19:04 ronald-jaepel