
`AxClient.get_best_trial()` produces the wrong best trial index with `use_model_predictions=True`.

Open · leandrobbraga opened this issue 2 years ago · 8 comments

I was working with the Service API in a constrained single-objective setting, and get_best_trial() returns the parameters of an existing trial, but with the wrong trial index.

This is the dataframe containing all the trials: [image: dataframe of all trials]

The output of get_best_trial(): [image: get_best_trial() output]

The second image says that these parameters belong to the 29th trial, which is not true; they actually belong to the 4th trial.

When I pass use_model_predictions=False, it gives the correct index (10).
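Since the returned parameters are correct and only the index is wrong, one possible stopgap is to look up the trial whose parameters match the returned ones. This is a minimal sketch, not something proposed in the thread: `find_trial_index_by_parameters` is a hypothetical helper, and it assumes single-arm trials, for which the input mapping could be built with something like `{i: t.arm.parameters for i, t in ax_client.experiment.trials.items()}`. Exact equality should be adequate here on the assumption that the returned parameters are copied from an existing arm.

```python
from typing import Any, Dict, Mapping, Optional


def find_trial_index_by_parameters(
    trial_parameters: Mapping[int, Dict[str, Any]],
    best_params: Dict[str, Any],
) -> Optional[int]:
    """Return the lowest trial index whose parameters equal best_params.

    Hypothetical helper: trial_parameters maps trial index -> parameter dict.
    Returns None if no trial matches.
    """
    matches = [
        index for index, params in trial_parameters.items() if params == best_params
    ]
    return min(matches) if matches else None


# Toy data standing in for an experiment's trials (not real results).
trials = {0: {"x": 2.5}, 1: {"x": 7.0}, 2: {"x": 4.0}}
print(find_trial_index_by_parameters(trials, {"x": 7.0}))  # 1
```

If the same parameters were evaluated in several trials, the sketch returns the earliest one.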

leandrobbraga, May 24 '23

Thanks for catching this; this is indeed a bug on our end. @saitcakmak, could you take a look at fixing this?

mpolson64, May 24 '23

If you guys are OK with it, I could try to fix this issue.

leandrobbraga, Jun 08 '23

I did some work today and managed to write a failing test for it:

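    from unittest import TestCase
    from unittest.mock import patch

    from ax.modelbridge.cross_validation import AssessModelFitResult, assess_model_fit
    from ax.service.ax_client import AxClient, ObjectiveProperties
    from ax.service.utils.best_point import (
        get_best_parameters_from_model_predictions_with_trial_index,
    )


    class TestGetBestTrial(TestCase):  # class name is illustrative
        # Force a "good fit" for the objective metric "y" so that get_best_trial()
        # takes the model-prediction path. Because return_value is set explicitly,
        # the wrapped assess_model_fit is not actually called.
        @patch(
            f"{get_best_parameters_from_model_predictions_with_trial_index.__module__}"
            + ".assess_model_fit",
            wraps=assess_model_fit,
            return_value=AssessModelFitResult(
                good_fit_metrics_to_fisher_score={"y": 1},
                bad_fit_metrics_to_fisher_score={},
            ),
        )
        def test_get_best_point_with_model_prediction(
            self,
            mock_assess_model_fit,
        ) -> None:
            ax_client = AxClient()
            ax_client.create_experiment(
                name="test_experiment",
                parameters=[
                    {
                        "name": "x",
                        "type": "range",
                        "bounds": [1.0, 10.0],
                    },
                ],
                objectives={"y": ObjectiveProperties(minimize=True)},
                is_test=True,
                choose_generation_strategy_kwargs={"num_initialization_trials": 4},
            )

            # Trial 0 gets the unique minimum (no tie), so it should be the best trial.
            params, idx = ax_client.get_next_trial()
            ax_client.complete_trial(idx, raw_data={"y": 0})

            # Trials 1-4 all do worse; trial 4 is the first model-generated trial.
            for i in range(1, 5):
                _, trial_index = ax_client.get_next_trial()
                ax_client.complete_trial(trial_index, raw_data={"y": i})

            # ax_client.get_next_trial()
            best_index, best_params, _ = ax_client.get_best_trial()
            # Expected to fail until the bug is fixed: best_index does not
            # point at the trial that produced best_params.
            self.assertEqual(best_index, idx)
            self.assertEqual(best_params, params)
            mock_assess_model_fit.assert_called()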

I know that the issue is in the get_best_parameters_from_model_predictions_with_trial_index function in ax/service/utils/best_point.py. If I understand correctly, it's using the GeneratorRun index, not the index of the actual best parameters.

leandrobbraga, Jun 08 '23

> If I understand correctly, it's using the GeneratorRun index, not the index of the actual best parameters.

Yes, that's exactly the issue. I had discussed this with Miles, but we forgot to update the issue here. It returns the index of the last GeneratorRun whose model can be used to evaluate the arms and find the best-performing one. That index has nothing to do with the trial of the actual predicted best arm.
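To make the distinction concrete, here is a simplified, hypothetical model of the behaviour described above; it is not Ax's code, and `FakeGeneratorRun`, `best_trial_buggy` and `best_trial_fixed` are illustrative names. Each trial has one generator run, and only model-generated runs can predict the objective. The buggy version returns the position of the run it used, while the fixed version maps the predicted best arm back to its own trial. The real fix in Ax may look different.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple


@dataclass
class FakeGeneratorRun:
    """Hypothetical stand-in for a trial's generator run.

    predict maps an arm name to a predicted objective value, or is None
    when the run was not produced by a usable model (e.g. Sobol).
    """

    predict: Optional[Callable[[str], float]]


def best_trial_buggy(
    runs: List[FakeGeneratorRun], arm_to_trial: Dict[str, int]
) -> Optional[Tuple[int, str]]:
    # runs[i] belongs to trial i. The returned index is the position of the
    # last run that has a model, NOT the trial of the predicted best arm.
    for idx in reversed(range(len(runs))):
        predict = runs[idx].predict
        if predict is not None:
            best_arm = min(arm_to_trial, key=predict)  # minimizing
            return idx, best_arm
    return None


def best_trial_fixed(
    runs: List[FakeGeneratorRun], arm_to_trial: Dict[str, int]
) -> Optional[Tuple[int, str]]:
    # Same search, but the index is looked up from the predicted best arm.
    for run in reversed(runs):
        if run.predict is not None:
            best_arm = min(arm_to_trial, key=run.predict)
            return arm_to_trial[best_arm], best_arm
    return None


# Three single-arm trials; only the last one was generated by a model.
arm_to_trial = {"0_0": 0, "1_0": 1, "2_0": 2}
predicted = {"0_0": 0.1, "1_0": 0.5, "2_0": 0.9}
runs = [
    FakeGeneratorRun(None),
    FakeGeneratorRun(None),
    FakeGeneratorRun(predicted.__getitem__),
]

print(best_trial_buggy(runs, arm_to_trial))  # (2, '0_0'): wrong trial index
print(best_trial_fixed(runs, arm_to_trial))  # (0, '0_0')
```

Both versions pick the same best arm; only the reported index differs, which matches the symptom in the original report (correct parameters, wrong index).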

saitcakmak, Jun 08 '23

Is this issue still "in progress"? Is there a workaround to get the best trial index?

noppelmax, Jan 11 '24

Hi @noppelmax. I don't think anyone has been working on this. The issue is with the model prediction based best trial index. If you call AxClient.get_best_trial(use_model_predictions=False), it will use the raw observations to find the best point and return the correct trial index.
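The workaround can be wrapped so that callers only receive the trial index and parameters. This is a minimal sketch: `get_best_observed_trial` is a hypothetical helper, and the only Ax call it relies on is `get_best_trial(use_model_predictions=False)` with the three-element return value used in the test earlier in this thread. The `None` check is defensive, on the assumption that `get_best_trial()` can return nothing when no best trial is available.

```python
from typing import Any, Dict, Optional, Tuple


def get_best_observed_trial(ax_client: Any) -> Optional[Tuple[int, Dict[str, Any]]]:
    """Return (trial_index, parameters) ranked on raw observations.

    Hypothetical convenience wrapper around the workaround from this thread:
    use_model_predictions=False bypasses the model-prediction code path that
    reports the wrong trial index.
    """
    result = ax_client.get_best_trial(use_model_predictions=False)
    if result is None:
        # Defensive: treat a missing result as "no best trial yet".
        return None
    trial_index, parameters, _predictions = result
    return trial_index, parameters
```

Usage would be `best = get_best_observed_trial(ax_client)` once some trials have been completed. The trade-off is that the best point is chosen from noisy raw observations rather than from the model's smoothed predictions.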

saitcakmak, Jan 16 '24

Looks like @saitcakmak provided a workaround here, and this has been inactive for quite some time. @leandrobbraga, please reopen if you'd like to continue the discussion! We likely won't see further activity on a closed issue.

lena-kashtelyan, Jul 31 '24

@lena-kashtelyan This is still a bug with use_model_predictions=True, which is the default. Let's keep it open for tracking. If we get around to a cleanup or rewrite of the best-point utils, this should be fixed in the process as well.

saitcakmak, Jul 31 '24