Ax
`AxClient.get_best_trial()` produces the wrong best trial index with `use_model_predictions=True`.
I was working with the Service API in a constrained single-objective setting, and `get_best_trial()` returns the parameters of a trial, but with the wrong trial index.
This is the dataframe containing all the trials:
The output of `get_best_trial()`:
The second image says that these parameters belong to the 29th trial, which is not true: they actually belong to the 4th trial.
When I pass `use_model_predictions=False`, it returns the correct index (10).
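One way to confirm which trial the returned parameters actually came from is to search the trials table for rows with the same parameter values. The sketch below uses a hypothetical helper, `find_trials_with_params`, which is not part of Ax. It assumes each row is a dict with a `trial_index` key plus one key per parameter. Something like `ax_client.get_trials_data_frame().to_dict("records")` should produce rows of roughly this shape, though the exact columns may vary between Ax versions.

```python
from typing import Any, Dict, List


def find_trials_with_params(
    rows: List[Dict[str, Any]],
    params: Dict[str, Any],
    tol: float = 1e-9,
) -> List[int]:
    """Return the indices of all trials whose parameters match `params`.

    Float values are compared with an absolute tolerance `tol`;
    all other values are compared with ==.
    """
    matches: List[int] = []
    for row in rows:
        is_match = all(
            abs(row[name] - value) <= tol
            if isinstance(value, float)
            else row[name] == value
            for name, value in params.items()
        )
        if is_match:
            matches.append(row["trial_index"])
    return matches


# Hypothetical rows standing in for the trials table.
rows = [
    {"trial_index": 0, "x": 2.5, "y": 3.0},
    {"trial_index": 1, "x": 7.0, "y": 1.0},
    {"trial_index": 2, "x": 4.25, "y": 2.0},
]
print(find_trials_with_params(rows, {"x": 7.0}))  # [1]
```

If the index reported by `get_best_trial()` is not in the returned list, the index and the parameters are out of sync.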
Thanks for catching this; this is indeed a bug on our end. @saitcakmak, could you take a look at fixing this?
If you guys are ok with it, I could try to solve this issue.
I did some work today and managed to write a failing test for it:
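```python
from unittest import TestCase
from unittest.mock import patch

# Import paths match the Ax versions current when this issue was filed;
# newer releases may have moved some of these modules.
from ax.modelbridge.cross_validation import AssessModelFitResult, assess_model_fit
from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.service.utils.best_point import (
    get_best_parameters_from_model_predictions_with_trial_index,
)


class TestBestPoint(TestCase):
    # Force the model-fit check to report a good fit for the objective "y",
    # so that the model-prediction code path is exercised.
    @patch(
        f"{get_best_parameters_from_model_predictions_with_trial_index.__module__}"
        + ".assess_model_fit",
        wraps=assess_model_fit,
        return_value=AssessModelFitResult(
            good_fit_metrics_to_fisher_score={"y": 1},
            bad_fit_metrics_to_fisher_score={},
        ),
    )
    def test_get_best_point_with_model_prediction(
        self,
        mock_assess_model_fit,
    ) -> None:
        ax_client = AxClient()
        ax_client.create_experiment(
            name="test_experiment",
            parameters=[
                {
                    "name": "x",
                    "type": "range",
                    "bounds": [1.0, 10.0],
                },
            ],
            objectives={"y": ObjectiveProperties(minimize=True)},
            is_test=True,
            choose_generation_strategy_kwargs={"num_initialization_trials": 4},
        )
        # Trial 0 gets the lowest value of y, so it should be reported as best.
        params, idx = ax_client.get_next_trial()
        ax_client.complete_trial(idx, raw_data={"y": 1})
        for i in range(1, 5):
            _, trial_index = ax_client.get_next_trial()
            # y = 2, 3, 4, 5 for trials 1-4, so there is no tie with trial 0.
            ax_client.complete_trial(trial_index, raw_data={"y": i + 1})
        # ax_client.get_next_trial()
        best_index, best_params, _ = ax_client.get_best_trial()
        self.assertEqual(best_index, idx)
        self.assertEqual(best_params, params)
        mock_assess_model_fit.assert_called()
```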
I believe the issue is in the `get_best_parameters_from_model_predictions_with_trial_index` function in `ax/service/utils/best_point.py`. If I understand correctly, it's using the `GeneratorRun` index, not the index of the trial with the best parameters.

> If I understand correctly, it's using the `GeneratorRun` index, not the index of the trial with the best parameters.

Yes, that's exactly the issue. I had discussed this with Miles, but we forgot to update it here. The function returns the index of the last `GeneratorRun` that has a model that can be used to evaluate the arms and find the best-performing one. That index has nothing to do with the actual predicted best arm.
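To make the distinction concrete, here is a simplified, hypothetical model of the behavior described above; it is not Ax's actual code. `Trial`, `has_model`, and the `predict` callable are stand-ins invented for illustration, and the objective is assumed to be minimized. The buggy variant reports the index of the last trial whose run carried a model, while the corrected variant reports the index of the trial that contains the predicted best arm.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple

Params = Dict[str, float]


@dataclass
class Trial:
    index: int
    params: Params
    has_model: bool  # True if the run that generated this trial carried a fitted model


def predicted_best(trials: List[Trial], predict: Callable[[Params], float]) -> Trial:
    # Minimization assumed: the best trial has the lowest predicted value.
    return min(trials, key=lambda t: predict(t.params))


def buggy_best_trial(
    trials: List[Trial], predict: Callable[[Params], float]
) -> Tuple[Optional[int], Params]:
    # Reported bug: the index comes from the last run with a model,
    # regardless of which trial holds the predicted best arm.
    last_with_model = max((t.index for t in trials if t.has_model), default=None)
    return last_with_model, predicted_best(trials, predict).params


def fixed_best_trial(
    trials: List[Trial], predict: Callable[[Params], float]
) -> Tuple[int, Params]:
    # Fix: report the index of the trial that actually contains the best arm.
    best = predicted_best(trials, predict)
    return best.index, best.params


def predict(p: Params) -> float:
    # Toy surrogate model with its minimum at x = 1.
    return (p["x"] - 1.0) ** 2


trials = [
    Trial(0, {"x": 1.0}, has_model=False),  # initialization (e.g. Sobol) trials
    Trial(1, {"x": 5.0}, has_model=False),
    Trial(2, {"x": 9.0}, has_model=True),  # model-generated trials
    Trial(3, {"x": 3.0}, has_model=True),
]
print(buggy_best_trial(trials, predict))  # (3, {'x': 1.0})
print(fixed_best_trial(trials, predict))  # (0, {'x': 1.0})
```

Both variants return the same parameters; only the index differs, which matches the symptom in the original report.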
Is this issue still in progress? Is there a workaround to get the best trial index?
Hi @noppelmax. I don't think anyone has been working on this. The issue is with the model-prediction-based best trial index. If you call `AxClient.get_best_trial(use_model_predictions=False)`, it will use the raw observations to find the best point and return the correct trial index.
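For a single objective with upper-bound outcome constraints, picking the best trial from raw observations comes down to filtering out infeasible trials and taking the best observed objective value. The sketch below illustrates that idea in plain Python. `best_trial_from_raw_data` and its arguments are hypothetical, and Ax's own `use_model_predictions=False` path may handle details such as noise, lower-bound constraints, or missing data differently.

```python
from typing import Dict, Optional


def best_trial_from_raw_data(
    observations: Dict[int, Dict[str, float]],
    objective: str,
    minimize: bool = True,
    upper_bounds: Optional[Dict[str, float]] = None,
) -> Optional[int]:
    """Return the index of the best feasible trial, using observed values only.

    `observations` maps trial index -> {metric name: observed mean}.
    `upper_bounds` maps constrained metric name -> maximum allowed value.
    Returns None if no trial satisfies every constraint.
    """
    bounds = upper_bounds or {}
    feasible = [
        idx
        for idx, metrics in observations.items()
        if all(metrics[name] <= bound for name, bound in bounds.items())
    ]
    if not feasible:
        return None
    sign = 1.0 if minimize else -1.0
    return min(feasible, key=lambda idx: sign * observations[idx][objective])


observations = {
    0: {"y": 1.0, "c": 0.5},
    1: {"y": 0.2, "c": 2.0},  # best objective, but violates c <= 1.0
    2: {"y": 0.7, "c": 0.1},
}
print(best_trial_from_raw_data(observations, "y", upper_bounds={"c": 1.0}))  # 2
```

Because the returned index is taken directly from the observations, it cannot drift away from the parameters the way the model-prediction path does.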
Looks like @saitcakmak provided a workaround here, and this has been inactive for quite some time. @leandrobbraga, please reopen if you'd like to continue the discussion! We likely won't see further activity on a closed issue.
@lena-kashtelyan This is still a bug with `use_model_predictions=True`, which is the default. Let's keep it open for tracking. If we get around to a cleanup or rewrite of the best point utils, this should be fixed in the process as well.