Ax
Ax copied to clipboard
EHVI & NEHVI break with more than 7 objectives
Hello Ax Team,
When running EHVI or NEHVI with more than 7 objectives, `get_next_trial` raises an `IndexError` during candidate generation, while the acquisition function is being evaluated.
Here's a minimal reproducible example (MRE):
import numpy as np
from ax.service.ax_client import AxClient, ObjectiveProperties
# Number of objectives to optimize; the failure reported below appears once
# this exceeds 7.
N_OBJECTIVES = 8

# Build a two-parameter experiment with N_OBJECTIVES minimized objectives,
# each with an objective threshold of 1.
ax_client = AxClient()
ax_client.create_experiment(
    name="test_experiment",
    parameters=[
        dict(name="x1", type="range", bounds=[-5.0, 10.0], value_type="float"),
        dict(name="x2", type="range", bounds=[0.0, 10.0], value_type="float"),
    ],
    objectives={
        f"Objective_{k}": ObjectiveProperties(minimize=True, threshold=1)
        for k in range(N_OBJECTIVES)
    },
)
def objective_function():
    """Return a dummy evaluation with one random value per objective.

    The trial's parameters are ignored on purpose: this function only has to
    produce data with the right shape to reproduce the failure.

    Returns:
        dict mapping ``"Objective_<i>"`` (for ``i`` in ``range(N_OBJECTIVES)``)
        to a random float drawn by ``np.random.rand()``.
    """
    res = {
        f"Objective_{i}": np.random.rand() for i in range(N_OBJECTIVES)
    }
    return res
# Run 15 ask/tell iterations. The traceback below shows the failure is raised
# from get_next_trial, inside the BoTorch model's candidate generation.
# Presumably the earliest trials do not use the model and succeed; confirm.
for _ in range(15):
    parameters, trial_index = ax_client.get_next_trial()
    ax_client.complete_trial(trial_index=trial_index, raw_data=objective_function())
And here's the full traceback:
File "...\ax\service\ax_client.py", line 531, in get_next_trial
generator_run=self._gen_new_generator_run(), ttl_seconds=ttl_seconds
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\ax\service\ax_client.py", line 1763, in _gen_new_generator_run
return not_none(self.generation_strategy).gen(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\ax\modelbridge\generation_strategy.py", line 478, in gen
return self._gen_multiple(
^^^^^^^^^^^^^^^^^^^
File "...\ax\modelbridge\generation_strategy.py", line 675, in _gen_multiple
generator_run = self._curr.gen(
^^^^^^^^^^^^^^^
File "...\ax\modelbridge\generation_node.py", line 737, in gen
gr = super().gen(
^^^^^^^^^^^^
File "...\ax\modelbridge\generation_node.py", line 307, in gen
generator_run = model_spec.gen(
^^^^^^^^^^^^^^^
File "...\ax\modelbridge\model_spec.py", line 219, in gen
return fitted_model.gen(**model_gen_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\ax\modelbridge\base.py", line 784, in gen
gen_results = self._gen(
^^^^^^^^^^
File "...\ax\modelbridge\torch.py", line 690, in _gen
gen_results = not_none(self.model).gen(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\ax\models\torch\botorch_modular\model.py", line 428, in gen
candidates, expected_acquisition_value = acqf.optimize(
^^^^^^^^^^^^^^
File "...\ax\models\torch\botorch_modular\acquisition.py", line 439, in optimize
return optimize_acqf(
^^^^^^^^^^^^^^
File "...\botorch\optim\optimize.py", line 563, in optimize_acqf
return _optimize_acqf(opt_acqf_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\botorch\optim\optimize.py", line 584, in _optimize_acqf
return _optimize_acqf_batch(opt_inputs=opt_inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\botorch\optim\optimize.py", line 274, in _optimize_acqf_batch
batch_initial_conditions = opt_inputs.get_ic_generator()(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\botorch\optim\initializers.py", line 417, in gen_batch_initial_conditions
Y_rnd_curr = acq_function(
^^^^^^^^^^^^^
File "...\torch\nn\modules\module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\torch\nn\modules\module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\botorch\utils\transforms.py", line 305, in decorated
return method(cls, X, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "...\botorch\utils\transforms.py", line 259, in decorated
output = method(acqf, X, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\botorch\acquisition\multi_objective\logei.py", line 468, in forward
nehvi = self._compute_log_qehvi(samples=samples, X=X)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\botorch\acquisition\multi_objective\logei.py", line 267, in _compute_log_qehvi
return logmeanexp(logsumexp(log_areas_per_segment, dim=-1), dim=0)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\botorch\utils\safe_math.py", line 146, in logsumexp
return _inf_max_helper(torch.logsumexp, x=x, dim=dim, keepdim=keepdim)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "...\botorch\utils\safe_math.py", line 170, in _inf_max_helper
M = x.amax(dim=dim, keepdim=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: amax(): Expected reduction dim -1 to have non-zero size.