
Get trained model and optimization progress from a ModelBridge

Open · thgngu opened this issue 3 years ago · 0 comments

Hello, I have some basic questions about using Ax. My setup is:

package  | version
ax       | 0.2.4
botorch  | 0.6.2
gpytorch | 1.6.0
pytorch  | 1.10.0
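You can report the same setup from any environment with a small script that reads the installed versions directly. This is a minimal sketch that uses only the standard library (Python 3.8+). It assumes the PyPI distribution names ax-platform (the name Ax is published under), botorch, gpytorch and torch. installed_versions is a hypothetical helper name, not part of Ax.

```python
from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names (assumption: Ax is installed as "ax-platform", not "ax").
PACKAGES = ["ax-platform", "botorch", "gpytorch", "torch"]


def installed_versions(packages=PACKAGES):
    """Return {distribution name: installed version string, or None if missing}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed in this environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:12} {ver or 'not installed'}")
```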

I'm interested in getting information about the optimization progress. For example, I want to get best_arm_predictions in the following example (it's the Developer API example from here).

The only thing I changed is that instead of Models.GPEI, I use Models.BOTORCH_MODULAR, combining Surrogate(SingleTaskGP) as the GP and qNoisyExpectedImprovement as the acquisition function.

In my example, generator_run.best_arm_predictions returns None, whereas it works fine in the original example. Conversely, in the original example gpei.evaluate_acquisition_function() raises a NotImplementedError. In my example it runs, but when I pass 2 input points it returns only 1 value. Below is reproducible code for what I describe.

I want access to the training progress (the current best point), the model, and the acquisition function so that I can do some external calculations. What would be the best way to achieve this?
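If the "current best point" only needs to be the best *observed* arm, you can compute it from the experiment's data without the model. Note that this is not the model-predicted best that best_arm_predictions is meant to report. Below is a minimal sketch under these assumptions:

- The rows come from exp.fetch_data().df.to_dict("records"), using the columns arm_name, metric_name and mean.
- The parameters come from {name: arm.parameters for name, arm in exp.arms_by_name.items()}.

The helper best_observed is hypothetical and not part of Ax.

```python
from typing import Dict, List, Optional, Tuple


def best_observed(
    rows: List[dict],
    arm_parameters: Dict[str, dict],
    metric_name: str = "branin",
    minimize: bool = True,
) -> Optional[Tuple[str, dict, float]]:
    """Return (arm_name, parameters, mean) of the best observed arm, or None.

    rows: records with at least "arm_name", "metric_name" and "mean" keys.
    arm_parameters: maps arm name -> parameter dict.
    """
    candidates = [r for r in rows if r["metric_name"] == metric_name]
    if not candidates:
        return None
    pick = min if minimize else max  # the Branin objective above is minimized
    best = pick(candidates, key=lambda r: r["mean"])
    return best["arm_name"], arm_parameters[best["arm_name"]], best["mean"]


# Illustrative, hand-written values (not output of the script in this issue):
rows = [
    {"arm_name": "0_0", "metric_name": "branin", "mean": 12.3},
    {"arm_name": "1_0", "metric_name": "branin", "mean": 4.5},
]
params = {"0_0": {"x1": 1.0, "x2": 2.0}, "1_0": {"x1": 3.0, "x2": 4.0}}
print(best_observed(rows, params))  # ('1_0', {'x1': 3.0, 'x2': 4.0}, 4.5)
```

For noisy metrics the best observed mean can be misleading, which is one reason a model-based prediction is preferable when it is available.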

Thanks

from ax import *
from ax.metrics.branin import BraninMetric
from ax.models.torch.botorch_modular.surrogate import Surrogate
from botorch.models.gp_regression import SingleTaskGP
from botorch.acquisition.monte_carlo import qNoisyExpectedImprovement
from ax.core.observation import ObservationFeatures

class MockRunner(Runner):
    def run(self, trial):
        return {"name": str(trial.index)}


branin_search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x1", parameter_type=ParameterType.FLOAT, lower=-5, upper=10
        ),
        RangeParameter(
            name="x2", parameter_type=ParameterType.FLOAT, lower=0, upper=15
        ),
    ]
)

print('----------------- Original Example ----------------')
exp = Experiment(
    name="test_branin",
    search_space=branin_search_space,
    optimization_config=OptimizationConfig(
        objective=Objective(
            metric=BraninMetric(name="branin", param_names=["x1", "x2"]),
            minimize=True,
        ),
    ),
    runner=MockRunner(),
)

sobol = Models.SOBOL(exp.search_space)
for i in range(2):
    trial = exp.new_trial(generator_run=sobol.gen(1))
    trial.run()
    trial.mark_completed()

best_arm = None
for i in range(2):
    gpei = Models.GPEI(experiment=exp, data=exp.fetch_data())
    generator_run = gpei.gen(1)
    best_arm, _ = generator_run.best_arm_predictions
    print('best_arm:', best_arm)
    trial = exp.new_trial(generator_run=generator_run)
    trial.run()
    trial.mark_completed()

print(exp.fetch_data().df)
print('best_arm.parameters:')
try:
    print(best_arm.parameters)
except Exception as e:
    print(str(e))

print('evaluate_acquisition_function():')
Xtest = [ObservationFeatures(parameters={'x1': 8.0, 'x2': 6.0}),
         ObservationFeatures(parameters={'x1': 10.0, 'x2': 1.0})]

try:
    print(gpei.evaluate_acquisition_function(observation_features=Xtest))
except Exception as e:
    print(str(e))

print('----------------- botorch_modular Example ----------------')

exp = Experiment(
    name="test_branin",
    search_space=branin_search_space,
    optimization_config=OptimizationConfig(
        objective=Objective(
            metric=BraninMetric(name="branin", param_names=["x1", "x2"]),
            minimize=True,
        ),
    ),
    runner=MockRunner(),
)

sobol = Models.SOBOL(exp.search_space)
for i in range(2):
    trial = exp.new_trial(generator_run=sobol.gen(1))
    trial.run()
    trial.mark_completed()

best_arm = None
for i in range(2):
    gpei = Models.BOTORCH_MODULAR(
        experiment=exp, data=exp.fetch_data(),
        surrogate=Surrogate(SingleTaskGP),
        botorch_acqf_class=qNoisyExpectedImprovement
    )
    generator_run = gpei.gen(1)
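    # best_arm_predictions is Optional[Tuple[Arm, ...]]: unpack it like the
    # GPEI loop above, but guard against None (which is what comes back here).
    best_arm_predictions = generator_run.best_arm_predictions
    best_arm = best_arm_predictions[0] if best_arm_predictions is not None else None
    print('best_arm:', best_arm)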
    trial = exp.new_trial(generator_run=generator_run)
    trial.run()
    trial.mark_completed()

print(exp.fetch_data().df)
print('best_arm.parameters:')
try:
    print(best_arm.parameters)
except Exception as e:
    print(str(e))

print('evaluate_acquisition_function():')
try:
    print(gpei.evaluate_acquisition_function(observation_features=Xtest))
except Exception as e:
    print(str(e))
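Here is one plausible explanation for getting one value back for two points. It assumes that Ax 0.2.4 passes the transformed points to the BoTorch acquisition function as a single tensor of shape (n, d). BoTorch's q-batch acquisition functions, such as qNoisyExpectedImprovement, treat that input as one batch of n points. They return a single joint value for the whole batch rather than one value per point.

The toy pure-Python sketch below shows that reduction for a plain Monte Carlo qEI under minimization, used as a simpler stand-in for qNEI. The posterior samples are hand-picked for illustration and are not output from the script above. The sketch does not call BoTorch, and mc_q_ei and per_point_ei are hypothetical names.

```python
from typing import List, Sequence


def mc_q_ei(samples: Sequence[Sequence[float]], best_f: float) -> float:
    """Joint Monte Carlo q-EI for minimization.

    samples[s][j] is the s-th posterior sample of the objective at the j-th
    point of the q-batch. Returns ONE number: the mean over samples of the
    largest improvement achieved by any point in the batch.
    """
    total = sum(max(max(best_f - y, 0.0) for y in s) for s in samples)
    return total / len(samples)


def per_point_ei(samples: Sequence[Sequence[float]], best_f: float) -> List[float]:
    """Evaluate each point as its own batch (q = 1): one value per point."""
    q = len(samples[0])
    return [mc_q_ei([[s[j]] for s in samples], best_f) for j in range(q)]


# Three illustrative joint samples at two points A and B; best observed value 4.0.
samples = [[1.0, 3.0], [4.0, 2.0], [5.0, 5.0]]
print(mc_q_ei(samples, best_f=4.0))       # one joint value: 5/3
print(per_point_ei(samples, best_f=4.0))  # [1.0, 1.0]
```

If that is the cause, calling evaluate_acquisition_function once per ObservationFeatures (a one-element list each time) might give a value per point. In BoTorch terms, that corresponds to an input of shape (n, 1, d) rather than (n, d). Whether a given Ax version behaves this way is an assumption worth checking.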

thgngu, Sep 14 '22 00:09