[FEATURE REQUEST]: Create a "best-practice" method for optimizing mixture weights, including optimizing over a simplex search space
Motivation
A common problem in many domains is to optimize the weights for a mixture of K components. In this case the search space is a simplex. I didn't see this covered in the tutorials/docs (apologies if it is and I missed it) - it would be nice to have a standard recommendation on how to do this in Ax.
1. The simplest approach is to define the parameters as (K-1) weights in (0, 1) and add the constraint that their sum is <= 1. However, the fraction of the (hypercube) search space that satisfies this constraint is 1/(K-1)!, which shrinks faster than exponentially with K. It seems like Ax uses rejection sampling to satisfy constraints, which becomes impractical for large K. (I got `SearchSpaceExhausted: Rejection sampling error (specified maximum draws (10000) exhausted, without finding sufficiently many (1) candidates). This likely means that there are no new points left in the search space.` when I tried this with just 7 parameters. For more parameters I suspect max_draws would need to be prohibitively large.)
2. An alternative approach is to use a mapping (preferably a 1:1 bijection with minimal distortion) between the hypercube and the simplex, as an interface between BO and the objective function. Stick-breaking is one way to do this (sketched below), but it may distort distances near the boundary, which could be a problem if the optimum involves setting the weights of some components to zero, or if the user already has data from such points (which is not uncommon). Other methods like softmax are many-to-one, meaning the simplex maps onto multiple redundant regions of the search space, which seems like it would make optimization inefficient.
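For concreteness, here is a minimal sketch of one such hypercube-to-simplex map (plain stick-breaking). This is only an illustration of approach 2, not an Ax API, and note that pushing uniform draws through this particular map does not give a uniform distribution on the simplex.

```python
import numpy as np

def stick_breaking(u):
    """Map u in (0, 1)^(K-1) to weights w on the K-simplex (w >= 0, sum(w) == 1)."""
    u = np.asarray(u, dtype=float)
    w = np.empty(u.shape[0] + 1)
    remaining = 1.0  # length of the "stick" still unassigned
    for k, u_k in enumerate(u):
        w[k] = u_k * remaining  # break off a fraction u_k of what remains
        remaining -= w[k]
    w[-1] = remaining  # last component absorbs the leftover mass
    return w

print(stick_breaking([0.5, 0.5, 0.5]))  # -> [0.5, 0.25, 0.125, 0.125]
```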
Describe the solution you'd like to see implemented in Ax.
Stick-breaking seems like a good first approach, but there may be better alternatives.
Describe any alternatives you've considered to the above solution.
No response
Is this related to an existing issue in Ax or another repository? If so please include links to those Issues here.
No response
Code of Conduct
- [x] I agree to follow Ax's Code of Conduct
There's a fair amount of discussion about this in previous issues, starting from https://github.com/facebook/Ax/issues/727 - you should be able to trace most of them from there.
In my own work I currently do something along the lines of (1). The challenge with this type of approach is that you can have other constraints that prevent cleanly omitting one dimension to go from K to K-1.
Regarding (1), it's not hard in principle to sample uniformly from the simplex (we implement this in botorch here: https://github.com/pytorch/botorch/blob/main/botorch/utils/sampling.py#L187-L225). Things get a bit more complicated when there are additional constraints, but you can also handle this - in fact we have an option to do this via `fallback_to_sample_polytope`: https://github.com/facebook/Ax/blob/dbf7beffa233a878ab78c7687549ea3878e47e15/ax/generators/random/base.py#L72, which uses a hit-and-run sampler to sample from the constraint polytope: https://github.com/facebook/Ax/blob/dbf7beffa233a878ab78c7687549ea3878e47e15/ax/generators/random/base.py#L192-L205. Setting that will allow the sampling to fall back: `Generators.SOBOL(search_space=search_space, fallback_to_sample_polytope=True)`
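As a point of reference, a minimal sketch of drawing uniform simplex samples directly with BoTorch's `sample_simplex` (illustration only; this is not wired into an Ax generation strategy here):

```python
from botorch.utils.sampling import sample_simplex

K = 8  # number of mixture components
# Draw 16 points uniformly from the K-simplex; each row is non-negative and sums to 1.
weights = sample_simplex(d=K, n=16, qmc=True, seed=0)
print(weights.sum(dim=-1))  # all ones, up to numerical precision
```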
The issue with that is that we currently don't have good support for imposing equality constraints for the other generators that optimize acquisition functions. In principle that's not too hard, but it gets rather messy to express general constraints properly if there are discrete or categorical parameters (or relaxations thereof) involved. If none of these complications are there, then it'd mostly be a matter of doing the engineering to pass this through the full stack and handle it in the right places.
The alternative approach in 2 makes sense - you can see how we do this within the HitAndRunPolytopeSampler here:
https://github.com/pytorch/botorch/blob/main/botorch/utils/sampling.py#L512-L520, which operates in the subspace defined by the equality constraints and then embeds the sampled candidates in the ambient space. But this suffers from similar issues with general search spaces that could contain discrete, fixed, categorical parameters. It would also mean that we'd have to introduce yet another layer of transform to that subspace and back.
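For reference, a rough sketch of calling `HitAndRunPolytopeSampler` directly for the weights-sum-to-one case with an extra inequality constraint. The `(A, b)` / `(C, d)` argument conventions and keyword names here are my reading of the current BoTorch code, so treat them as assumptions and check the docstrings before relying on this:

```python
import torch
from botorch.utils.sampling import HitAndRunPolytopeSampler

K = 8
bounds = torch.stack([torch.zeros(K), torch.ones(K)])  # box bounds [0, 1]^K

# Equality constraint C x = d: the K weights sum to one.
C = torch.ones(1, K)
d = torch.ones(1, 1)

# An extra inequality constraint A x <= b, e.g. x_1 + x_2 <= 0.5.
A = torch.zeros(1, K)
A[0, :2] = 1.0
b = torch.tensor([[0.5]])

sampler = HitAndRunPolytopeSampler(
    inequality_constraints=(A, b),
    equality_constraints=(C, d),
    bounds=bounds,
)
samples = sampler.draw(n=8)  # 8 x K tensor of feasible mixture weights
```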
> 1. The simplest approach is to define the parameters as (K-1) weights in (0, 1) and add the constraint that their sum is <= 1. However, the fraction of the (hypercube) search space that satisfies this constraint is 1/(K-1)!, which shrinks faster than exponentially with K. It seems like Ax uses rejection sampling to satisfy constraints, which becomes impractical for large K. (I got `SearchSpaceExhausted: Rejection sampling error (specified maximum draws (10000) exhausted, without finding sufficiently many (1) candidates). This likely means that there are no new points left in the search space.` when I tried this with just 7 parameters. For more parameters I suspect max_draws would need to be prohibitively large.)
A reproducer (e.g., on Colab) might help. I've used this "hidden variable" approach many times, ranging from a few to a couple dozen parameters. I remember one time where SearchSpaceExhausted was thrown due to choosing to hide a variable with a very narrow range. Another straightforward option is to use a large set of predefined candidates that meet the constraints, but you lose out on gains from using gradients during acquisition function optimization (i.e., a brute force approach, which scales quite poorly).
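As an aside, one way to build such a predefined candidate pool is to draw compositions from a flat Dirichlet, which is uniform on the simplex, and then filter on any remaining constraints (plain NumPy sketch, not an Ax API):

```python
import numpy as np

rng = np.random.default_rng(0)
K = 8
# Dirichlet(1, ..., 1) is the uniform distribution on the K-simplex; each row sums to 1.
candidates = rng.dirichlet(np.ones(K), size=10_000)
# Filter on any additional constraints before using the rows as a fixed candidate pool,
# e.g. x_1 + x_2 <= 0.5:
feasible = candidates[candidates[:, 0] + candidates[:, 1] <= 0.5]
```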
> The challenge with this type of approach is that you can have other constraints that prevent cleanly omitting one dimension to go from K to K-1.
I remember needing to do something like this here due to having both composition and permutation invariance constraints (latter represented via order constraints). See Appendix B from https://doi.org/10.1016/j.commatsci.2023.112134 [preprint]. Mathematica or other symbolic solver approaches can be a good spot check when there are several constraints involved.
@sgbaird Here's a minimal reproducible example. If you have a working example, would be interested to see it! Theoretically, it feels like rejection sampling should fail like this in high dimensions b/c you're trying to get a draw from a very small region of the hypercube.
@Balandat Thanks for the context! My two cents - it would be quite valuable to have explicit simplex sampling in Ax, even if it doesn't work with some other constraints, just because mixtures are a common class of "optimize over many parameters" problems. I'm trying (2) rn but I'm not sure whether or not it's distorting the space to a degree that is harmful.
@kurt-essential, from what I could tell, the last variable wasn't "hidden". See https://colab.research.google.com/drive/1c40f9K9ej-QaLrzVQQHPwSvCixl6SOjK?usp=sharing
Copied below for provenance
```python
# !pip install ax-platform
import numpy as np

from ax.api.client import Client
from ax.api.configs import RangeParameterConfig

client = Client()

n_params = 5  # Only optimize x1 to x5, x6 will be calculated

# Expanded version of parameter creation:
parameters = [
    RangeParameterConfig(name="x1", parameter_type="float", bounds=(0, 1)),
    RangeParameterConfig(name="x2", parameter_type="float", bounds=(0, 1)),
    RangeParameterConfig(name="x3", parameter_type="float", bounds=(0, 1)),
    RangeParameterConfig(name="x4", parameter_type="float", bounds=(0, 1)),
    RangeParameterConfig(name="x5", parameter_type="float", bounds=(0, 1)),
]

# Expanded version of sum constraint:
sum_constraint = "x1 + x2 + x3 + x4 + x5 <= 1.0"

client.configure_experiment(
    parameters=parameters,
    parameter_constraints=[sum_constraint],
)

metric_name = "hartmann6"  # this name is used during the optimization loop below
objective = f"-{metric_name}"  # minimization is specified by the negative sign

client.configure_optimization(objective=objective)

# Note: not sure how to increase max_draws to avoid SearchSpaceExhausted
# client.configure_generation_strategy(max_rs_draws=200000)


# Hartmann6 function
def hartmann6(x1, x2, x3, x4, x5, x6):
    alpha = np.array([1.0, 1.2, 3.0, 3.2])
    A = np.array([
        [10, 3, 17, 3.5, 1.7, 8],
        [0.05, 10, 17, 0.1, 8, 14],
        [3, 3.5, 1.7, 10, 17, 8],
        [17, 8, 0.05, 10, 0.1, 14],
    ])
    P = 10**-4 * np.array([
        [1312, 1696, 5569, 124, 8283, 5886],
        [2329, 4135, 8307, 3736, 1004, 9991],
        [2348, 1451, 3522, 2883, 3047, 6650],
        [4047, 8828, 8732, 5743, 1091, 381],
    ])
    outer = 0.0
    for i in range(4):
        inner = 0.0
        for j, x in enumerate([x1, x2, x3, x4, x5, x6]):
            inner += A[i, j] * (x - P[i, j]) ** 2
        outer += alpha[i] * np.exp(-inner)
    return -outer


# Wrapper function that calculates x6 and calls hartmann6
def evaluate_hartmann6(x1, x2, x3, x4, x5):
    """
    Calculate x6 as the remainder to make the sum equal to 1,
    then evaluate the hartmann6 function.
    """
    x6 = 1.0 - (x1 + x2 + x3 + x4 + x5)
    return hartmann6(x1, x2, x3, x4, x5, x6)


# Test the function
result = evaluate_hartmann6(0.1, 0.3, 0.2, 0.15, 0.05)
print(f"Function result: {result}")

for _ in range(10):  # Run 10 rounds of trials
    # We will request three trials at a time in this example
    trials = client.get_next_trials(max_trials=3)

    for trial_index, parameters in trials.items():
        x1 = parameters["x1"]
        x2 = parameters["x2"]
        x3 = parameters["x3"]
        x4 = parameters["x4"]
        x5 = parameters["x5"]

        result = evaluate_hartmann6(x1, x2, x3, x4, x5)

        # Set raw_data as a dictionary with metric names as keys and results as values
        raw_data = {metric_name: result}

        # Complete the trial with the result
        client.complete_trial(trial_index=trial_index, raw_data=raw_data)
        print(f"Completed trial {trial_index} with {raw_data=}")
```
@sgbaird
In my example I constrained x1:x8 to sum to <= 1. You can assume WLOG that there is a "hidden" x9. (Of course the objective function only depends on x1:x6 but that shouldn't be relevant to the matter at hand.) I found that once you increase the search space to ~8 parameters you get SearchSpaceExhausted. The example in your post only has 5 parameters - I also saw that it works at that scale, just not at larger scales.
```python
n_params = 8

parameters = [
    RangeParameterConfig(
        name=f"x{i}", parameter_type="float", bounds=(0, 1)
    )
    for i in range(1, n_params + 1)
]

sum_constraint = " + ".join([f"x{i}" for i in range(1, n_params + 1)]) + " <= 1.0"
```
> I found that once you increase the search space to ~8 parameters you get SearchSpaceExhausted. The example in your post only has 5 parameters - I also saw that it works at that scale, just not at larger scales.
Do you have a repro with a stack trace for this?
This is likely b/c we're naively sampling from the constrained space x1 + ... + x8 <= 1.0 using rejection sampling. The fraction of the hypercube volume that is feasible here is 1/8! (about 2.5e-5), so with max_draws=10000 rejection sampling is all but guaranteed to fail.
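A quick Monte Carlo sanity check of that fraction (plain NumPy sketch):

```python
import math
import numpy as np

rng = np.random.default_rng(0)
K = 8
draws = rng.random((1_000_000, K))  # uniform draws from the unit hypercube
frac = (draws.sum(axis=1) <= 1.0).mean()
# Both are ~2.5e-5, i.e. roughly one feasible point per 40,000 draws.
print(frac, 1 / math.factorial(K))
```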
We have a fallback_to_sample_polytope option in our random generator that will fall back to a different method for drawing random samples, which should not result in this error: https://github.com/facebook/Ax/blob/b2adff918719b742d4c17c3cfcf16d7ae8acae5b/ax/generators/random/base.py#L62
It's actually not hard to sample uniformly from the simplex directly (BoTorch implementation here: https://github.com/pytorch/botorch/blob/main/botorch/utils/sampling.py#L195), but without having this simplex constraint represented as a first-class object in Ax it's not easy to auto-dispatch to that.
cc @sdaulton re considerations for simplex constraints with DerivedParameters.
Here's the stack trace from the repro that @kurt-essential shared (that was my bad for getting confused and reducing it to a 5D problem earlier). Copying it below:
```
---------------------------------------------------------------------------
SearchSpaceExhausted Traceback (most recent call last)
<ipython-input-14-782463899> in <cell line: 0>()
1 for _ in range(10): # Run 10 rounds of trials
2 # We will request three trials at a time in this example
----> 3 trials = client.get_next_trials(max_trials=3)
4
5 for trial_index, parameters in trials.items():
14 frames
/usr/local/lib/python3.11/dist-packages/ax/api/client.py in get_next_trials(self, max_trials, fixed_parameters)
371
372 # This will be changed to use gen directly post gen-unfication cc @mgarrard
--> 373 generator_runs = gs.gen_for_multiple_trials_with_multiple_models(
374 experiment=self._experiment,
375 pending_observations=(
/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_strategy.py in gen_for_multiple_trials_with_multiple_models(self, experiment, data, pending_observations, n, fixed_features, num_trials, arms_per_node)
426 for _i in range(num_trials):
427 grs_for_multiple_trials.append(
--> 428 self._gen_with_multiple_nodes(
429 experiment=experiment,
430 data=data,
/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_strategy.py in _gen_with_multiple_nodes(self, experiment, n, pending_observations, data, fixed_features, arms_per_node, first_generation_in_multi)
776 transitioned = self._maybe_transition_to_next_node()
777 try:
--> 778 gr = self._curr.gen(
779 experiment=experiment,
780 data=data,
/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in gen(self, experiment, pending_observations, skip_fit, data, **gs_gen_kwargs)
443 )
444 except Exception as e:
--> 445 gr = self._try_gen_with_fallback(
446 exception=e,
447 experiment=experiment,
/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in _try_gen_with_fallback(self, exception, experiment, n, data, pending_observations, **model_gen_kwargs)
564 error_type = type(exception)
565 if error_type not in self.fallback_specs:
--> 566 raise exception
567
568 # identify fallback model to use
/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in gen(self, experiment, pending_observations, skip_fit, data, **gs_gen_kwargs)
436 # duplicate of a previous active arm (e.g. not from a failed trial)
437 # on the experiment.
--> 438 gr = self._gen_maybe_deduplicate(
439 experiment=experiment,
440 data=data,
/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in _gen_maybe_deduplicate(self, experiment, n, pending_observations, data, **model_gen_kwargs)
528 while n_gen_draws < MAX_GEN_ATTEMPTS:
529 n_gen_draws += 1
--> 530 gr = self._gen(
531 experiment=experiment,
532 data=data,
/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/generation_node.py in _gen(self, experiment, n, pending_observations, data, **model_gen_kwargs)
496 # override the one set in `model_spec.model_gen_kwargs`.
497 n = model_spec.model_gen_kwargs.get("n", None)
--> 498 return model_spec.gen(
499 experiment=experiment,
500 data=data,
/usr/local/lib/python3.11/dist-packages/ax/generation_strategy/model_spec.py in gen(self, **model_gen_kwargs)
236 # copy to ensure there is no in-place modification
237 model_gen_kwargs = deepcopy(model_gen_kwargs)
--> 238 generator_run = fitted_model.gen(**model_gen_kwargs)
239 fit_and_std_quality_and_generalization_dict = (
240 get_fit_and_std_quality_and_generalization_dict(
/usr/local/lib/python3.11/dist-packages/ax/modelbridge/base.py in gen(self, n, search_space, optimization_config, pending_observations, fixed_features, model_gen_options)
806 )
807 # Apply terminal transform and gen
--> 808 gen_results = self._gen(
809 n=n,
810 search_space=base_gen_args.search_space,
/usr/local/lib/python3.11/dist-packages/ax/modelbridge/random.py in _gen(self, n, search_space, pending_observations, fixed_features, optimization_config, model_gen_options)
125
126 # Generate the candidates
--> 127 X, w = self.model.gen(
128 n=n,
129 bounds=search_space_digest.bounds,
/usr/local/lib/python3.11/dist-packages/ax/models/random/sobol.py in gen(self, n, bounds, linear_constraints, fixed_features, model_gen_options, rounding_func, generated_points)
109 if len(tf_indices) > 0:
110 self.init_engine(len(tf_indices))
--> 111 points, weights = super().gen(
112 n=n,
113 bounds=bounds,
/usr/local/lib/python3.11/dist-packages/ax/models/random/base.py in gen(self, n, bounds, linear_constraints, fixed_features, model_gen_options, rounding_func, generated_points)
195 # TODO: Should this round & deduplicate?
196 else:
--> 197 raise e
198
199 self.attempted_draws = attempted_draws
/usr/local/lib/python3.11/dist-packages/ax/models/random/base.py in gen(self, n, bounds, linear_constraints, fixed_features, model_gen_options, rounding_func, generated_points)
146 # constraints or actual duplicates and deduplicate is specified.
147 # If rejection sampling fails, fall back to polytope sampling
--> 148 points, attempted_draws = rejection_sample(
149 gen_unconstrained=self._gen_unconstrained,
150 n=n,
/usr/local/lib/python3.11/dist-packages/ax/models/model_utils.py in rejection_sample(gen_unconstrained, n, d, tunable_feature_indices, linear_constraints, deduplicate, max_draws, fixed_features, rounding_func, existing_points)
158 if successful_draws < n:
159 # Only possible if attempted_draws >= max_draws.
--> 160 raise SearchSpaceExhausted(
161 f"Rejection sampling error (specified maximum draws ({max_draws}) exhausted"
162 f", without finding sufficiently many ({n}) candidates). This likely means "
SearchSpaceExhausted: Rejection sampling error (specified maximum draws (10000) exhausted, without finding sufficiently many (1) candidates). This likely means that there are no new points left in the search space.
```
Yeah definitely rejection sampling to blame. Using fallback_to_sample_polytope should be the quickest way around, see my comment here: https://github.com/facebook/Ax/issues/2510#issuecomment-2160828417
@Balandat yep that's what I suspected :) see my colab example above
FYI support for a DerivedParameter is added in https://github.com/facebook/Ax/pull/4142. This would allow specifying x6 as a function of the other parameters and using a sum constraint that x1 + ... + x5 <= 1.
Conceptually, in the case that this is the only constraint, it seems like we could just use sample_simplex (with one additional point) internally to generate points uniformly and avoid the issues with rejection sampling and the need for long chains with the hit-and-run sampler, but as @Balandat notes it would take some care to properly dispatch to that.
With the DerivedParameter, can we put constraints on the derived parameter value? I.e., can I do something like x1 + ... + x5 <= 1, x1 + x6 < 0.2, x6 = 1 - (x1 + ... + x5)?
No, for now we require that the parameter constraints be specified in terms of the other parameters. That's a simple rewrite of the constraint, though.
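(For example, substituting x6 = 1 - (x1 + ... + x5) into x1 + x6 < 0.2 gives x1 + 1 - (x1 + x2 + x3 + x4 + x5) < 0.2, which simplifies to x2 + x3 + x4 + x5 > 0.8 and only involves the non-derived parameters.)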
> FYI support for a DerivedParameter is added in #4142. This would allow specifying x6 as a function of the other parameters and using a sum constraint that x1 + ... + x5 <= 1.
>
> Conceptually, in the case that this is the only constraint, it seems like we could just use sample_simplex (with one additional point) internally to generate points uniformly and avoid the issues with rejection sampling and the need for long chains with the hit-and-run sampler, but as @Balandat notes it would take some care to properly dispatch to that.
@sdaulton this is great to see! Could you comment a bit - in its current implementation, is this purely for plumbing convenience or does this already affect surrogate modeling (underlying kernel) and/or visualization (e.g., feature importance plots)?