Support attaching data only for some metrics in Ax Service API

Open prtkm opened this issue 2 years ago • 5 comments

Hi,

I was wondering if you could help me figure out how to set up a Bayesian optimization problem in Ax with multiple objectives over the same parameter space. In our experimental dataset, some objectives are evaluated for only a subset of the trials (i.e., we have more trials for some objectives than others, so the training set size differs between objectives). I would like to have one model that fits all the objectives, and then an acquisition function that suggests the next trials. I think I should be able to do this with ModelListGP, but I haven't been able to find an example.

Prateek

prtkm avatar Mar 29 '23 23:03 prtkm

We should already automatically be dispatching to the ModelListGP model in cases where not all objectives are observed for all trials. Have you tried this?

Balandat avatar Mar 30 '23 04:03 Balandat

I have tried something like this, where I use the Service API and loop over a data frame to attach trials. This fails because some metrics are not observed for some of the trials, and those entries are NaNs in the data frame. What is the correct way to attach a trial in this case?

prtkm avatar Mar 30 '23 15:03 prtkm

@lena-kashtelyan we should also allow attaching partial trial results via the Service API. Who would be the right person to take a look at this?

Balandat avatar Mar 30 '23 16:03 Balandat

Sorry for the very delayed response! I'll see what we can do about allowing this setting.

lena-kashtelyan avatar Jul 25 '23 19:07 lena-kashtelyan

Hi @prtkm. AxClient currently supports attaching partial observations when completing trials. We intend to change this to add some validation, but the current behavior is to use whatever data is available and fit a model to each metric using the data available for it. Candidate generation should work as long as there is some data for every metric. Here's an example, modified from the Service API tutorial, that has observations for l2norm only on odd trials. Note that the metric is simply left out of the returned dict on the trials where we don't have data for it. Returning None, NaN, etc. will lead to errors.

from ax.service.ax_client import AxClient, ObjectiveProperties
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import init_notebook_plotting, render

init_notebook_plotting()

ax_client = AxClient()

ax_client.create_experiment(
    name="hartmann_test_experiment",
    parameters=[
        {
            "name": "x1",
            "type": "range",
            "bounds": [0.0, 1.0],
            "value_type": "float",  # Optional, defaults to inference from type of "bounds".
            "log_scale": False,  # Optional, defaults to False.
        },
        {
            "name": "x2",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x3",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x4",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x5",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
        {
            "name": "x6",
            "type": "range",
            "bounds": [0.0, 1.0],
        },
    ],
    objectives={"hartmann6": ObjectiveProperties(minimize=True)},
    outcome_constraints=["l2norm <= 1.25"],  # Optional.
)

import numpy as np


def evaluate(trial_index, parameters):
    # Only provide l2 norm for odd trials.
    x = np.array([parameters.get(f"x{i+1}") for i in range(6)])
    # In our case, standard error is 0, since we are computing a synthetic function.
    if trial_index % 2 == 1:
        return {"hartmann6": (hartmann6(x), 0.0), "l2norm": (np.sqrt((x**2).sum()), 0.0)}
    else:
        return {"hartmann6": (hartmann6(x), 0.0)}

for i in range(15):
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=evaluate(trial_index, parameters))
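For the data-frame case raised earlier in the thread, the same idea can be sketched as below. This is a minimal sketch under stated assumptions, not an official recipe: `row_to_raw_data` and `attach_observed_rows` are hypothetical helper names, `rows` is assumed to be a list of dicts (for example from `df.to_dict("records")`), and the SEM default of 0.0 only mirrors the synthetic example above and should be replaced with a real measurement error for experimental data. The point is to omit unobserved metrics from `raw_data` instead of passing NaN.

```python
import math


def row_to_raw_data(row, metric_names, sem=0.0):
    """Build a raw_data dict from one record, omitting unobserved metrics.

    Hypothetical helper: `row` maps column names to values, and a metric
    counts as unobserved when its value is missing, None, or NaN.
    `sem=0.0` is an assumption copied from the synthetic example.
    """
    raw_data = {}
    for name in metric_names:
        value = row.get(name)
        if value is None or math.isnan(value):
            continue  # leave the metric out entirely rather than passing NaN
        raw_data[name] = (float(value), sem)
    return raw_data


def attach_observed_rows(ax_client, rows, parameter_names, metric_names):
    """Attach each record as a trial and complete it with the observed metrics."""
    for row in rows:
        raw_data = row_to_raw_data(row, metric_names)
        if not raw_data:
            continue  # nothing was measured for this record, so skip it
        parameters = {name: row[name] for name in parameter_names}
        # attach_trial returns (parameters, trial_index)
        _, trial_index = ax_client.attach_trial(parameters=parameters)
        ax_client.complete_trial(trial_index=trial_index, raw_data=raw_data)
```

With the experiment above, a call might look like `attach_observed_rows(ax_client, df.to_dict("records"), [f"x{i+1}" for i in range(6)], ["hartmann6", "l2norm"])`, where `df` is an assumed data frame with one column per parameter and metric.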

saitcakmak avatar Nov 30 '23 19:11 saitcakmak

Closing this as inactive

saitcakmak avatar Apr 30 '24 21:04 saitcakmak