
[FEATURE REQUEST]: Create a "best-practice" method for optimizing mixture weights, including optimizing over a simplex search space

Open kurt-essential opened this issue 6 months ago • 14 comments

Motivation

A common problem in many domains is to optimize the weights for a mixture of K components. In this case the search space is a simplex. I didn't see this covered in the tutorials/docs (apologies if it is and I missed it) - it would be nice to have a standard recommendation on how to do this in Ax.

  1. The simplest approach is to define the parameters as (K-1) weights in (0,1) and set the constraint that their sum <= 1. However, the fraction of the (hypercube) search space that satisfies this constraint is 1/(K-1)!, which shrinks factorially with K. Ax appears to use rejection sampling to satisfy constraints, which becomes impractical for large K. (I got "SearchSpaceExhausted: Rejection sampling error (specified maximum draws (10000) exhausted, without finding sufficiently many (1) candidates). This likely means that there are no new points left in the search space." when I tried this with just 7 parameters. For more parameters I suspect max_draws would need to be prohibitively large.)

  2. An alternative approach is to use a mapping (preferably a bijection with minimal distortion) between the hypercube and the simplex, as an interface between BO and the objective function. Stick-breaking is one way to do this (a sketch follows below), but it may distort distances near the boundary, which could be a problem if the optimum involves setting the weights of some components to zero, or if the user already has data from such points (which is not uncommon). Other methods like softmax are many-to-one, meaning the simplex maps to multiple redundant regions of the search space, which seems like it would make optimization inefficient.
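
A minimal sketch of the plain stick-breaking map from (2), in NumPy (the function name is mine; the caveat at the end notes the standard Beta-quantile correction needed for uniform coverage):

import numpy as np

def stick_breaking(u):
    """Map u in (0, 1)^(K-1) to weights w on the K-simplex (sum(w) == 1)."""
    w = np.empty(len(u) + 1)
    remaining = 1.0
    for i, u_i in enumerate(u):
        w[i] = u_i * remaining  # break off a fraction u_i of the remaining stick
        remaining -= w[i]
    w[-1] = remaining  # the last weight absorbs whatever is left
    return w

# Example: 3 hypercube coordinates -> 4 mixture weights
print(stick_breaking(np.array([0.5, 0.5, 0.5])))  # [0.5, 0.25, 0.125, 0.125]

# Caveat: pushing uniform u through this map does NOT give uniform coverage
# of the simplex. For that, first transform each u_i through the Beta(1, K-1-i)
# quantile, i.e. v_i = 1 - (1 - u_i) ** (1 / (K - 1 - i)) for i = 0, ..., K-2.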

Describe the solution you'd like to see implemented in Ax.

Stick-breaking seems like a good first approach, but there may be better alternatives.

Describe any alternatives you've considered to the above solution.

No response

Is this related to an existing issue in Ax or another repository? If so please include links to those Issues here.

No response

Code of Conduct

  • [x] I agree to follow Ax's Code of Conduct

kurt-essential avatar Jun 04 '25 17:06 kurt-essential

There's a fair amount of discussion about this in previous issues, starting from https://github.com/facebook/Ax/issues/727; you should be able to trace most of them from there.

In my own work I currently do something along the lines of 1. The challenge I encounter with this type of approach is that it's possible to encounter problems where you have other constraints that prevent cleanly omitting one dimension to go from K to K-1.

CompRhys avatar Jun 07 '25 00:06 CompRhys

Regarding 1, it's not hard in principle to sample uniformly from the simplex (we implement this in botorch here: https://github.com/pytorch/botorch/blob/main/botorch/utils/sampling.py#L187-L225). Things get a bit more complicated when there are additional constraints, but you can also handle this - in fact we have an option to do this via fallback_to_sample_polytope: https://github.com/facebook/Ax/blob/dbf7beffa233a878ab78c7687549ea3878e47e15/ax/generators/random/base.py#L72, which uses a hit-and-run sampler to sample from the constraint polytope: https://github.com/facebook/Ax/blob/dbf7beffa233a878ab78c7687549ea3878e47e15/ax/generators/random/base.py#L192-L205. Setting that will allow the sampling to fall back: Generators.SOBOL(search_space=search_space, fallback_to_sample_polytope=True)

The issue with that is that we currently don't have good support for imposing equality constraints for the other generators that optimize acquisition functions. In principle that's not too hard, but it gets rather messy to express general constraints properly if there are discrete or categorical parameters or relaxations thereof involved. If none of these complications are present, then it'd be mostly a matter of doing the engineering to pass this through the full stack and handle it in the right places.

The alternative approach in 2 makes sense - you can see how we do this within the HitAndRunPolytopeSampler here: https://github.com/pytorch/botorch/blob/main/botorch/utils/sampling.py#L512-L520, which operates in the subspace defined by the equality constraints and then embeds the sampled candidates in the ambient space. But this suffers from similar issues with general search spaces that could contain discrete, fixed, categorical parameters. It would also mean that we'd have to introduce yet another layer of transform to that subspace and back.
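
For concreteness, a rough sketch of that sampler applied to the simplex equality constraint in isolation (plain BoTorch, not wired into Ax; constraint tensors follow BoTorch's A @ x <= b and C @ x = d convention):

import torch
from botorch.utils.sampling import HitAndRunPolytopeSampler

d = 6
# Inequality constraints x >= 0, written as -I @ x <= 0.
A_ineq = -torch.eye(d)
b_ineq = torch.zeros(d, 1)
# Equality constraint sum(x) = 1.
C_eq = torch.ones(1, d)
d_eq = torch.ones(1, 1)

sampler = HitAndRunPolytopeSampler(
    inequality_constraints=(A_ineq, b_ineq),
    equality_constraints=(C_eq, d_eq),
)
samples = sampler.draw(n=10)  # 10 x 6 tensor; every row lies on the simplex
print(samples.sum(dim=-1))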

Balandat avatar Jun 07 '25 13:06 Balandat

  1. The simplest approach is to define the parameters as (K-1) weights in (0,1) and set the constraint that their sum <= 1. However, the fraction of the (hypercube) search space that satisfies this constraint is 1/(K-1)!, which shrinks factorially with K. Ax appears to use rejection sampling to satisfy constraints, which becomes impractical for large K. (I got "SearchSpaceExhausted: Rejection sampling error (specified maximum draws (10000) exhausted, without finding sufficiently many (1) candidates). This likely means that there are no new points left in the search space." when I tried this with just 7 parameters. For more parameters I suspect max_draws would need to be prohibitively large.)

A reproducer (e.g., on Colab) might help. I've used this "hidden variable" approach many times, with anywhere from a few to a couple dozen parameters. I remember one case where SearchSpaceExhausted was thrown because the variable chosen to be hidden had a very narrow range. Another straightforward option is to use a large set of predefined candidates that meet the constraints (a sketch of building such a pool is below), but it's essentially a brute-force approach that scales quite poorly, and you lose the gains from using gradients during acquisition function optimization.
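
As a sketch of building such a candidate pool (relying on the fact that Dirichlet(1, ..., 1) is the uniform distribution on the simplex):

import numpy as np

# Pre-generate a pool of feasible mixture candidates. Dirichlet(1, ..., 1)
# is exactly the uniform distribution on the probability simplex.
rng = np.random.default_rng(0)
K = 8
pool = rng.dirichlet(np.ones(K), size=10_000)
assert np.allclose(pool.sum(axis=1), 1.0)  # every candidate sums to 1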

The challenge I encounter with this type of approach is that it's possible to encounter problems where you have other constraints that prevent cleanly omitting one dimension to go from K to K-1.

I remember needing to do something like this here due to having both composition and permutation invariance constraints (latter represented via order constraints). See Appendix B from https://doi.org/10.1016/j.commatsci.2023.112134 [preprint]. Mathematica or other symbolic solver approaches can be a good spot check when there are several constraints involved.

sgbaird avatar Jun 10 '25 14:06 sgbaird

@sgbaird Here's a minimal reproducible example. If you have a working example, would be interested to see it! Theoretically, it feels like rejection sampling should fail like this in high dimensions b/c you're trying to get a draw from a very small region of the hypercube.
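
A quick Monte Carlo sanity check of that intuition (the feasible fraction of the unit hypercube under the constraint sum(x) <= 1 is 1/K!):

import numpy as np
from math import factorial

rng = np.random.default_rng(0)
for k in range(2, 9):
    u = rng.uniform(size=(200_000, k))
    accept = (u.sum(axis=1) <= 1.0).mean()
    print(f"K={k}: empirical acceptance {accept:.1e} vs 1/K! = {1 / factorial(k):.1e}")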

@Balandat Thanks for the context! My two cents - it would be quite valuable to have explicit simplex sampling in Ax, even if it doesn't work with some other constraints, just because mixtures are a common class of "optimize over many parameters" problems. I'm trying (2) rn but I'm not sure whether or not it's distorting the space to a degree that is harmful.

kurt-essential avatar Jun 11 '25 07:06 kurt-essential

@kurt-essential, from what I could tell, the last variable wasn't "hidden". See https://colab.research.google.com/drive/1c40f9K9ej-QaLrzVQQHPwSvCixl6SOjK?usp=sharing

Copied below for provenance

# !pip install ax-platform
import numpy as np
from ax.api.client import Client
from ax.api.configs import RangeParameterConfig

client = Client()

n_params = 5  # Only optimize x1 to x5, x6 will be calculated

# Expanded version of parameter creation:
parameters = [
    RangeParameterConfig(name="x1", parameter_type="float", bounds=(0, 1)),
    RangeParameterConfig(name="x2", parameter_type="float", bounds=(0, 1)),
    RangeParameterConfig(name="x3", parameter_type="float", bounds=(0, 1)),
    RangeParameterConfig(name="x4", parameter_type="float", bounds=(0, 1)),
    RangeParameterConfig(name="x5", parameter_type="float", bounds=(0, 1))
]

# Expanded version of sum constraint:
sum_constraint = "x1 + x2 + x3 + x4 + x5 <= 1.0"

client.configure_experiment(parameters=parameters,
                            parameter_constraints = [sum_constraint]
                            )

metric_name = "hartmann6" # this name is used during the optimization loop below
objective = f"-{metric_name}" # minimization is specified by the negative sign

client.configure_optimization(objective=objective)

# Note: not sure how to increase max_draws to avoid SearchSpaceExhausted
# client.configure_generation_strategy(max_rs_draws=200000)

# Hartmann6 function
def hartmann6(x1, x2, x3, x4, x5, x6):
    alpha = np.array([1.0, 1.2, 3.0, 3.2])
    A = np.array([
        [10, 3, 17, 3.5, 1.7, 8],
        [0.05, 10, 17, 0.1, 8, 14],
        [3, 3.5, 1.7, 10, 17, 8],
        [17, 8, 0.05, 10, 0.1, 14]
    ])
    P = 10**-4 * np.array([
        [1312, 1696, 5569, 124, 8283, 5886],
        [2329, 4135, 8307, 3736, 1004, 9991],
        [2348, 1451, 3522, 2883, 3047, 6650],
        [4047, 8828, 8732, 5743, 1091, 381]
    ])

    outer = 0.0
    for i in range(4):
        inner = 0.0
        for j, x in enumerate([x1, x2, x3, x4, x5, x6]):
            inner += A[i, j] * (x - P[i, j])**2
        outer += alpha[i] * np.exp(-inner)
    return -outer

# Wrapper function that calculates x6 and calls hartmann6
def evaluate_hartmann6(x1, x2, x3, x4, x5):
    """
    Calculate x6 as the remainder to make the sum equal to 1,
    then evaluate the hartmann6 function.
    """
    x6 = 1.0 - (x1 + x2 + x3 + x4 + x5)
    return hartmann6(x1, x2, x3, x4, x5, x6)

# Test the function
result = evaluate_hartmann6(0.1, 0.3, 0.2, 0.15, 0.05)
print(f"Function result: {result}")

for _ in range(10): # Run 10 rounds of trials
    # We will request three trials at a time in this example
    trials = client.get_next_trials(max_trials=3)

    for trial_index, parameters in trials.items():
        x1 = parameters["x1"]
        x2 = parameters["x2"]
        x3 = parameters["x3"]
        x4 = parameters["x4"]
        x5 = parameters["x5"]

        result = evaluate_hartmann6(x1, x2, x3, x4, x5)

        # Set raw_data as a dictionary with metric names as keys and results as values
        raw_data = {metric_name: result}

        # Complete the trial with the result
        client.complete_trial(trial_index=trial_index, raw_data=raw_data)

        print(f"Completed trial {trial_index} with {raw_data=}")

sgbaird avatar Aug 18 '25 20:08 sgbaird

@sgbaird In my example I constrained x1 through x8 so that their sum is <= 1. You can assume WLOG that there is a "hidden" x9. (Of course the objective function only depends on x1:x6, but that shouldn't be relevant to the matter at hand.) I found that once you increase the search space to ~8 parameters you get SearchSpaceExhausted. The example in your post only has 5 parameters - I also saw that it works at that scale, just not at larger ones.

n_params = 8
parameters = [
    RangeParameterConfig(
        name=f"x{i}", parameter_type="float", bounds=(0, 1)
    ) for i in range(1, n_params+1)
]
sum_constraint = " + ".join([f"x{i}" for i in range(1, n_params+1)]) + " <= 1.0"

kurt-essential avatar Aug 18 '25 21:08 kurt-essential

I found that once you increase the search space to ~8 parameters you get SearchSpaceExhausted. The example in your post only has 5 parameters - I also saw that it works at that scale, just not at larger scales.

Do you have a repro with a stack trace for this?

This is likely b/c we're naively sampling from the constrained space x1 + ... + x8 <= 1.0 using rejection sampling. The fraction of the feasible volume for this constraint is 1/8! ≈ 2.5e-5.

We have a fallback_to_sample_polytope option in our random generator that will use a different method to draw random samples, which should not result in this error: https://github.com/facebook/Ax/blob/b2adff918719b742d4c17c3cfcf16d7ae8acae5b/ax/generators/random/base.py#L62

It's actually not hard to sample uniformly from the simplex directly (BoTorch implementation here: https://github.com/pytorch/botorch/blob/main/botorch/utils/sampling.py#L195), but without having this simplex constraint represented as a first-class object in Ax it's not easy to auto-dispatch to that.
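
For reference, a minimal sketch of that helper used in isolation (plain BoTorch, not dispatched from Ax):

import torch
from botorch.utils.sampling import sample_simplex

# Draw 10 points uniformly at random from the 8-dimensional probability
# simplex; each row is non-negative and sums to 1.
X = sample_simplex(d=8, n=10, seed=0)
print(X.sum(dim=-1))  # tensor of ones (up to floating point error)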

cc @sdaulton re considerations for simplex constraints with DerivedParameters.

Balandat avatar Aug 18 '25 23:08 Balandat

That was my bad for getting confused and reducing it to a 5D problem. Here's the stack trace from the Colab repro that @kurt-essential shared:

---------------------------------------------------------------------------
SearchSpaceExhausted                      Traceback (most recent call last)
<ipython-input-14-782463899> in <cell line: 0>()
      1 for _ in range(10): # Run 10 rounds of trials
      2     # We will request three trials at a time in this example
----> 3     trials = client.get_next_trials(max_trials=3)
      4 
      5     for trial_index, parameters in trials.items():

14 frames
/usr/local/lib/python3.11/dist-packages/ax/api/client.py in get_next_trials(self, max_trials, fixed_parameters)
    371 
    372             # This will be changed to use gen directly post gen-unfication cc @mgarrard
--> 373             generator_runs = gs.gen_for_multiple_trials_with_multiple_models(
    374                 experiment=self._experiment,
    375                 pending_observations=(

/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_strategy.py in gen_for_multiple_trials_with_multiple_models(self, experiment, data, pending_observations, n, fixed_features, num_trials, arms_per_node)
    426         for _i in range(num_trials):
    427             grs_for_multiple_trials.append(
--> 428                 self._gen_with_multiple_nodes(
    429                     experiment=experiment,
    430                     data=data,

/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_strategy.py in _gen_with_multiple_nodes(self, experiment, n, pending_observations, data, fixed_features, arms_per_node, first_generation_in_multi)
    776             transitioned = self._maybe_transition_to_next_node()
    777             try:
--> 778                 gr = self._curr.gen(
    779                     experiment=experiment,
    780                     data=data,

/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in gen(self, experiment, pending_observations, skip_fit, data, **gs_gen_kwargs)
    443             )
    444         except Exception as e:
--> 445             gr = self._try_gen_with_fallback(
    446                 exception=e,
    447                 experiment=experiment,

/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in _try_gen_with_fallback(self, exception, experiment, n, data, pending_observations, **model_gen_kwargs)
    564         error_type = type(exception)
    565         if error_type not in self.fallback_specs:
--> 566             raise exception
    567 
    568         # identify fallback model to use

/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in gen(self, experiment, pending_observations, skip_fit, data, **gs_gen_kwargs)
    436             # duplicate of a previous active arm (e.g. not from a failed trial)
    437             # on the experiment.
--> 438             gr = self._gen_maybe_deduplicate(
    439                 experiment=experiment,
    440                 data=data,

/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in _gen_maybe_deduplicate(self, experiment, n, pending_observations, data, **model_gen_kwargs)
    528         while n_gen_draws < MAX_GEN_ATTEMPTS:
    529             n_gen_draws += 1
--> 530             gr = self._gen(
    531                 experiment=experiment,
    532                 data=data,

/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in _gen(self, experiment, n, pending_observations, data, **model_gen_kwargs)
    496             # override the one set in `model_spec.model_gen_kwargs`.
    497             n = model_spec.model_gen_kwargs.get("n", None)
--> 498         return model_spec.gen(
    499             experiment=experiment,
    500             data=data,

/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/model_spec.py in gen(self, **model_gen_kwargs)
    236         # copy to ensure there is no in-place modification
    237         model_gen_kwargs = deepcopy(model_gen_kwargs)
--> 238         generator_run = fitted_model.gen(**model_gen_kwargs)
    239         fit_and_std_quality_and_generalization_dict = (
    240             get_fit_and_std_quality_and_generalization_dict(

/usr/local/lib/python3.11/dist-packages/ax/modelbridge/base.py in gen(self, n, search_space, optimization_config, pending_observations, fixed_features, model_gen_options)
    806         )
    807         # Apply terminal transform and gen
--> 808         gen_results = self._gen(
    809             n=n,
    810             search_space=base_gen_args.search_space,

/usr/local/lib/python3.11/dist-packages/ax/modelbridge/random.py in _gen(self, n, search_space, pending_observations, fixed_features, optimization_config, model_gen_options)
    125 
    126         # Generate the candidates
--> 127         X, w = self.model.gen(
    128             n=n,
    129             bounds=search_space_digest.bounds,

/usr/local/lib/python3.11/dist-packages/ax/models/random/sobol.py in gen(self, n, bounds, linear_constraints, fixed_features, model_gen_options, rounding_func, generated_points)
    109         if len(tf_indices) > 0:
    110             self.init_engine(len(tf_indices))
--> 111         points, weights = super().gen(
    112             n=n,
    113             bounds=bounds,

/usr/local/lib/python3.11/dist-packages/ax/models/random/base.py in gen(self, n, bounds, linear_constraints, fixed_features, model_gen_options, rounding_func, generated_points)
    195                 # TODO: Should this round & deduplicate?
    196             else:
--> 197                 raise e
    198 
    199         self.attempted_draws = attempted_draws

/usr/local/lib/python3.11/dist-packages/ax/models/random/base.py in gen(self, n, bounds, linear_constraints, fixed_features, model_gen_options, rounding_func, generated_points)
    146             # constraints or actual duplicates and deduplicate is specified.
    147             # If rejection sampling fails, fall back to polytope sampling
--> 148             points, attempted_draws = rejection_sample(
    149                 gen_unconstrained=self._gen_unconstrained,
    150                 n=n,

/usr/local/lib/python3.11/dist-packages/ax/models/model_utils.py in rejection_sample(gen_unconstrained, n, d, tunable_feature_indices, linear_constraints, deduplicate, max_draws, fixed_features, rounding_func, existing_points)
    158     if successful_draws < n:
    159         # Only possible if attempted_draws >= max_draws.
--> 160         raise SearchSpaceExhausted(
    161             f"Rejection sampling error (specified maximum draws ({max_draws}) exhausted"
    162             f", without finding sufficiently many ({n}) candidates). This likely means "

SearchSpaceExhausted: Rejection sampling error (specified maximum draws (10000) exhausted, without finding sufficiently many (1) candidates). This likely means that there are no new points left in the search space.

sgbaird avatar Aug 18 '25 23:08 sgbaird

Yeah definitely rejection sampling to blame. Using fallback_to_sample_polytope should be the quickest way around, see my comment here: https://github.com/facebook/Ax/issues/2510#issuecomment-2160828417

Balandat avatar Aug 18 '25 23:08 Balandat

@Balandat yep that's what I suspected :) see my colab example above

kurt-essential avatar Aug 19 '25 23:08 kurt-essential

FYI support for a DerivedParameter is added in https://github.com/facebook/Ax/pull/4142. This would allow specifying x6 as a function of the other parameters and using a sum constraint that x1 + ... + x5 <= 1.

Conceptually, in the case that this is the only constraint, it seems like we could just use sample_simplex (with one additional point) internally to generate points uniformly, avoiding both the issues with rejection sampling and the need for long chains with the hit-and-run sampler, but as @Balandat noted, it would take some care to properly dispatch to that.

sdaulton avatar Aug 25 '25 22:08 sdaulton

With the DerivedParameter, can we put constraints on the derived parameter value? I.e., can I do something like x1 + ... + x5 <= 1, x1 + x6 < 0.2, x6 = 1 - (x1 + ... + x5)?

CompRhys avatar Aug 26 '25 15:08 CompRhys

No, for now we require that the parameter constraints be specified in terms of the other parameters. That's a simple rewrite of the constraint, though: with x6 = 1 - (x1 + ... + x5), the constraint x1 + x6 < 0.2 becomes 1 - (x2 + x3 + x4 + x5) < 0.2, i.e., x2 + x3 + x4 + x5 > 0.8.

sdaulton avatar Aug 26 '25 16:08 sdaulton

FYI support for a DerivedParameter is added in #4142. This would allow specifying x6 as a function of the other parameters and using a sum constraint that x1 + ... + x5 <= 1.

Conceptually, in the case that this is the only constraint, it seems like we could just use sample_simplex (with one additional point) internally to generate points uniformly, avoiding both the issues with rejection sampling and the need for long chains with the hit-and-run sampler, but as @Balandat noted, it would take some care to properly dispatch to that.

@sdaulton this is great to see! Could you comment a bit - in its current implementation, is this purely for plumbing convenience or does this already affect surrogate modeling (underlying kernel) and/or visualization (e.g., feature importance plots)?

sgbaird avatar Aug 26 '25 16:08 sgbaird