Incorporate 'Candidate Selection' DoE method(s)
Use case: you have existing candidates -- in our example, a shelf full of additives -- and a budget for a fixed number X of experiments. You want to 'cover'/explore your space as well as possible with that fixed number of these candidates.
Solutions to this problem fall under the description 'candidate approach', and are not yet implemented in bofire.
Some info on possible implementations: (R packages being referenced due to the context of chatting with statisticians ;) )
- cluster the candidates into X clusters and choose one candidate per cluster [My knee-jerk reaction was this. I am not a statistician, and this is not a DoE, so it probably shouldn't be implemented]
- calculate the I-criterion for each set of X points from your total candidate set, rank by score, pick the best (see https://github.com/experimental-design/bofire/issues/337; a brute-force sketch follows this list). Should be fast, depending on the size of X and the number of total candidates [Chemist and Physicist suggestion]
- if your candidates are a grid of possible values in each dimension, there are well-established coordinate exchange algorithms that are fast and do a good job for this. The R package with this functionality is AlgDesign. [Statistician suggestion]
- can do the subset selection using a genetic algorithm (see this R package from our colleague ) [Statistician suggestion]
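For the I-criterion ranking idea above, a brute-force sketch (my own illustration, not bofire API): estimate the moments matrix from the candidates themselves and score every size-X subset by the average prediction variance. Only viable for small candidate sets, since the number of subsets grows combinatorially.

```python
from itertools import combinations

import numpy as np

def i_score(X_model, subset, M):
    """I-criterion: average prediction variance tr(M @ inv(X_S^T X_S))."""
    Xs = X_model[list(subset)]
    try:
        inv_info = np.linalg.inv(Xs.T @ Xs)
    except np.linalg.LinAlgError:
        return np.inf  # singular information matrix, subset unusable
    return np.trace(M @ inv_info)

# X_model: model matrix of all candidates (rows = candidates, cols = model terms);
# random numbers here just for illustration
X_model = np.random.default_rng(0).uniform(size=(12, 3))
n_select = 4
# moments matrix estimated by averaging over the candidate set itself
M = X_model.T @ X_model / len(X_model)
best = min(
    combinations(range(len(X_model)), n_select),
    key=lambda s: i_score(X_model, s, M),
)
print("best subset (row indices):", best)
```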
A first approach would involve picking one method to implement; in the future we could allow selecting among multiple methods (so don't block that when you solve this).
@dlinzner-bcs is this enough?
Just an idea, but maybe a starting point for solving this: one could multiply each candidate in the design matrix by a weight $w_i$ (ultimately the weights should be binary). Then one can optimize a criterion (e.g. D-optimality) w.r.t. the constraint $\sum_i w_i^2 = \sum_i w_i \leq n_{max}$, where $n_{max}$ is the maximum number of affordable experiments (the first equality only holds for the non-relaxed problem).
Assuming continuous weights, this is even a convex problem in the case of D-optimality and can be solved like this:
```python
# imports (a guess on my side; exact module paths may differ between bofire versions)
import bofire.strategies.api as strategies
from bofire.data_models.domain.api import Domain
from bofire.data_models.features.api import ContinuousInput, ContinuousOutput
from bofire.data_models.strategies.api import DoEStrategy
from bofire.data_models.strategies.doe import DOptimalityCriterion

# generate candidates
domain = Domain(
    inputs=[
        ContinuousInput(key="x1", bounds=(0, 1)),
        ContinuousInput(key="x2", bounds=(0, 1)),
        ContinuousInput(key="x3", bounds=(0, 1)),
    ],
    outputs=[ContinuousOutput(key="y")],
)
data_model = DoEStrategy(
    domain=domain,
    criterion=DOptimalityCriterion(formula="fully-quadratic"),
    ipopt_options={"disp": 0},
)
strategy = strategies.map(data_model=data_model)
candidates = strategy.ask(candidate_count=20)
```
```python
# solve the candidate selection problem (relaxed: continuous weights in [0, 1])
import cvxpy as cp
import numpy as np

X = candidates.to_numpy()
n_max = 18

x = cp.Variable(X.shape[0])  # one weight per candidate
objective = cp.Maximize(cp.log_det(X.T @ cp.diag(x) @ X))
constraints = [cp.sum(x) <= n_max, x >= 0, x <= 1]
prob = cp.Problem(objective, constraints)
# log_det needs a conic solver such as SCS; picking installed_solvers()[1] as in
# my first attempt is fragile, since the solver order is not guaranteed
result = prob.solve(solver=cp.SCS)
print(np.round(x.value, 3))
```
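To turn the relaxed weights into an actual selection, one naive option (a sketch, not a principled rounding) is to simply keep the n_max candidates with the largest weights:

```python
# naive rounding: keep the n_max candidates with the largest relaxed weights
idx = np.argsort(-x.value)[:n_max]
selection = candidates.iloc[idx]
print(selection)
```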
We would then only need a way to get back from the relaxed problem (with the continuous weights) to the discrete one. @dlinzner-bcs you worked on optimization with discrete variables, right? Do you have an idea how to solve this problem, given that we can solve the relaxed one?
@rosonaeldred : does this roughly go into the direction you imagine?
Cheers Aaron :)
I really like the idea with the weights!
I will now use this issue to describe the current approach and its problems for everyone included in the email discussion @WaStCo @rosonaeldred @dlinzner-bcs --> maybe you have ideas how to fix the current problems, or better ideas how to approach the candidate selection. I thought a bit longer about an integer programming approach to general DoE problems, with candidate selection being a special case of it.
Let's assume we have a grid $G$ consisting of $n$ points in the design space (one should think of this as a regular grid covering the full design space densely, or, alternatively, just a set of candidates), which might contain discrete and continuous decision variables.
Let $X \in \mathbb{R}^{n \times m}$ be the model matrix for all points in $G$ and a model with $m$ terms. For a tuple of non-negative integers $K := (k_1, \dots, k_n)$, let $X_K \in \mathbb{R}^{(k_1 + \dots + k_n) \times m}$ be the matrix derived from $X$ in which the $i$-th row is repeated $k_i$ times (for $k_i = 0$ the $i$-th row is removed). I.e., $X_K$ is the model matrix for a design in which the $i$-th point of $G$ occurs $k_i$ times. For a densely sampled grid of the design space, the model matrices $X_K$ that can be formed by choosing $K$ arbitrarily such that $k_1 + \dots + k_n = N$ approximate all possible model matrices of designs with $N$ experiments.
I think it is easy to show (it should follow directly from the multilinearity of the determinant) that for the D-criterion we have $\log(\det(X^T W_K X)) = \log(\det(X_K^T X_K))$, where $W_K = \mathrm{diag}(k_1, \dots, k_n)$.
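A quick numerical check of this identity (my sketch):

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))    # model matrix for 5 grid points, 3 model terms
K = np.array([2, 0, 1, 3, 1])  # repetition counts k_i
W_K = np.diag(K)
X_K = np.repeat(X, K, axis=0)  # row i repeated K[i] times
lhs = np.linalg.slogdet(X.T @ W_K @ X)[1]
rhs = np.linalg.slogdet(X_K.T @ X_K)[1]
print(np.isclose(lhs, rhs))  # True
```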
Therefore, to find a D-optimal design with $N$ experiments we can simply solve the following optimization problem:
$$\max_K \; \log(\det(X^T W_K X)) \quad \text{subject to} \quad \sum_i k_i = N, \quad K \in \mathbb{N}_0^n$$
This optimization problem has only one optimization variable per candidate, it is also valid for discrete design variables, and it is even convex (except for the integer requirement). However, in high dimensions the number of candidates in a dense grid (of continuous variables) will of course explode, so this approach can, if at all, only work for relatively low-dimensional problems.
When dropping the integer constraint, I found in small tests that the relaxed problem with around 1k decision variables can still be solved quickly in cvxpy or ipopt. But the only MINLP solver I know of (which supports solving the non-relaxed problem) is MindtPy, which has no native support for log_det, and with my own implementation it takes an eternity for MindtPy to compile (at least this is what I think :D) the optimization problems for anything beyond toy-problem size.
Now back to the candidate selection problem: to solve this type of problem, just set $G$ to your set of candidates, $N$ to the maximum number of candidates you want to select, and set the bounds of the components of $K$ to $[0, 1]$ (instead of any non-negative integer).
Do you guys know an MINLP solver that could solve this problem quickly for somewhat bigger problems (at least a few hundred to a few thousand candidates in $G$)? Otherwise one could also try to implement a branch-and-bound algorithm where a sequence of relaxed problems is solved using cvxpy, but this would be some work ...
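For reference, a minimal branch-and-bound sketch along those lines (my own illustration, untested for performance): solve the cvxpy relaxation at each node and branch on the most fractional weight. This is the 0/1 candidate-selection variant.

```python
import cvxpy as cp
import numpy as np

def solve_relaxation(X, n_max, lb, ub):
    """Relaxed D-optimality with box bounds on the per-candidate weights."""
    w = cp.Variable(X.shape[0])
    prob = cp.Problem(
        cp.Maximize(cp.log_det(X.T @ cp.diag(w) @ X)),
        [cp.sum(w) == n_max, w >= lb, w <= ub],
    )
    prob.solve(solver=cp.SCS)
    return prob.value, w.value

def branch_and_bound(X, n_max, tol=1e-3):
    best_val, best_w = -np.inf, None
    stack = [(np.zeros(X.shape[0]), np.ones(X.shape[0]))]  # (lower, upper) bounds
    while stack:
        lb, ub = stack.pop()
        val, w = solve_relaxation(X, n_max, lb, ub)
        if w is None or val <= best_val:
            continue  # infeasible node, or bound cannot beat the incumbent
        frac = np.abs(w - np.round(w))
        i = int(np.argmax(frac))
        if frac[i] < tol:  # (almost) integral solution -> new incumbent
            best_val, best_w = val, np.round(w)
            continue
        # branch on the most fractional weight: fix it to 0 or to 1
        ub0 = ub.copy(); ub0[i] = 0.0
        lb1 = lb.copy(); lb1[i] = 1.0
        stack += [(lb.copy(), ub0), (lb1, ub.copy())]
    return best_val, best_w
```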
Here is also the code for the toy example with MindtPy. We have a linear model with 2 design variables $x_1$ and $x_2$ and we want a design with 12 experiments. I think the expected behavior is that the corner points (0,0), (0,1), (1,0), (1,1) occur 3 times each (which is also the solution that the solver returns here).
```python
import numpy as np
import pandas as pd
import formulaic
# Pyomo imports (sqrt/log are Pyomo's symbolic versions, needed for the
# hand-rolled Cholesky decomposition below)
from pyomo.environ import (
    ConcreteModel, Var, Objective, Constraint, SolverFactory,
    NonNegativeIntegers, maximize, sqrt, log,
)

n_experiments = 12
X = pd.DataFrame(
    [
        [0.0, 0.0],
        [1.0, 0.0],
        [0.0, 1.0],
        [0.0, 0.5],
        [0.5, 0.0],
        [0.5, 0.5],
        [0.25, 0.25],
        [0.75, 0.25],
        [0.25, 0.75],
        [0.75, 0.75],
        [1.0, 1.0],
    ],
    columns=["x1", "x2"],
)
formula = formulaic.Formula("x1 + x2")
X = formula.get_model_matrix(X).to_numpy()
n = X.shape[0]

model = ConcreteModel()
model.w = Var(range(n), domain=NonNegativeIntegers, initialize=1, bounds=(0, None))

def compute_XTWX(model):
    # build the (regularized) information matrix X^T W X symbolically
    W = np.diag([model.w[i] for i in range(n)])
    model.fim = np.dot(X.T, np.dot(W, X)) + 1e-6 * np.eye(X.shape[1])

def log_det(model):
    # log(det(fim)) via a symbolic Cholesky decomposition fim = L L^T,
    # so that log(det(fim)) = 2 * sum_i log(L_ii)
    compute_XTWX(model)
    m = model.fim.shape[0]
    model.L = np.zeros_like(model.fim)
    for i in range(m):
        for j in range(i + 1):
            if i == j:
                model.L[i, j] = sqrt(model.fim[i, i] - sum(model.L[i, :i] ** 2))
            else:
                model.L[i, j] = (
                    model.fim[i, j] - sum(model.L[i, :j] * model.L[j, :j])
                ) / model.L[j, j]
    return sum(2 * log(model.L[i, i]) for i in range(m))

# objective function: log(det(X^T W X))
model.obj = Objective(rule=log_det, sense=maximize)
# constraint: at most n_experiments runs in total
model.Constraint1 = Constraint(expr=sum(model.w[i] for i in range(n)) <= n_experiments)

# solve with MindtPy (outer approximation, Gurobi for the MIPs, Ipopt for the NLPs)
solver = SolverFactory("mindtpy")
solver.solve(model, tee=True, strategy="OA", mip_solver="gurobi", nlp_solver="ipopt",
             mip_solver_tee=False, nlp_solver_tee=False)

# extract and print the optimal weights
optimal_w = [model.w[i].value for i in range(n)]
print("Optimal w:", optimal_w)
```
Cheers Aaron :)
Thank you @Osburg for this elegant solution. In my opinion, biting the bullet and thresholding the relaxed solution is the pragmatic choice. In the discrete optimization part, we only use branch and bound because it guarantees a feasible solution. The designs you yield here will all be feasible by design (lol). I expect a small return on investment from introducing branch and bound in your case. I am not aware of an appropriate MINLP solver. @R-M-Lee maybe?
Hey,
alright @dlinzner-bcs, thanks for the feedback. Then @rosonaeldred @WaStCo maybe try it out with this code (and please let me know if it works for you and whether you think it should be implemented as part of bofire).
This example has 3 decision variables and we assume a fully-quadratic model (i.e. 10 model terms). The candidates are a regular grid covering the design space with 23 points along each dimension (i.e. 12167 candidates --> that's roughly your problem size, right?) and we attempt to choose 100 points from this grid.
```python
from itertools import product

import cvxpy as cp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from formulaic import Formula

# regular 23 x 23 x 23 grid as candidate set
candidates = np.array(list(product(np.linspace(0, 1, 23), repeat=3)), dtype=np.float64)
X = pd.DataFrame(candidates, columns=["x1", "x2", "x3"])

# fully-quadratic model matrix
f = Formula("x1*x2 + x1*x3 + x2*x3 + {x1**2} + {x2**2} + {x3**2}")
X = f.get_model_matrix(X).to_numpy()

n_max = 100
x = cp.Variable(X.shape[0])  # weights
objective = cp.Maximize(cp.log_det(X.T @ cp.diag(x) @ X))
constraints = [cp.sum(x) == n_max, x >= 0, x <= 1]
prob = cp.Problem(objective, constraints)
result = prob.solve(solver="SCS", verbose=True, canon_backend=cp.SCIPY_CANON_BACKEND)
```
```python
# round the relaxed weights to {0, 1} while keeping the total count at n_max:
# track the accumulated rounding violation and round against its sign
mask = np.round(x.value, 2)
violation = 0.0
for i in range(len(mask)):
    delta = np.round(mask[i]) - mask[i]
    if delta * np.sign(violation) <= 0.0:
        mask[i] = np.round(mask[i])
    else:
        mask[i] = np.round(mask[i] - np.sign(violation))
    violation = np.sum(mask) - n_max
mask = np.array(mask, dtype=bool)
selected_candidates = candidates[mask, :]
```
selected_candidates = candidates[mask, :]
fig = plt.figure(figsize=((10, 8)))
ax = fig.add_subplot(111, projection="3d")
ax.view_init(45, 20)
ax.set_title("Space filling design")
ax.set_xlabel("$x_1$")
ax.set_ylabel("$x_2$")
ax.set_zlabel("$x_3$")
ax.scatter(xs=candidates[:, 0], ys=candidates[:, 1], zs=candidates[:, 2], s=5, color="blue", alpha=0.5)
ax.scatter(xs=selected_candidates[:, 0], ys=selected_candidates[:, 1], zs=selected_candidates[:, 2], s=40, color="red")
This is what the resulting design looks like.
I think the procedure used here to find an integer solution could also be generalized to the general DoE problem I described in my last comment.
Cheers Aaron :)
Edit: for the general DoE problem, just drop the constraint x <= 1; this then leads to the normal D-optimal design.
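I.e., the only change in the snippet above would be (sketch):

```python
# general DoE variant: weights may exceed 1, i.e. a candidate can be run repeatedly
constraints = [cp.sum(x) == n_max, x >= 0]
```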
I'm trying to track down the exact use case; once I have it, I will come back and comment. I can't tell if the relevant person is working or on leave this week, but I should know this afternoon. @Osburg @dlinzner-bcs .
Update: Dominik and I have a meeting with the person (and their anonymized data) Monday, so expect something next week.
Hi @rosonaeldred :) Cool, thx for the update!
@rosonaeldred @dlinzner-bcs Any updates on this?:)
[I will work to be better about this :P My only defense is a confluence of deadlines.] @Osburg @dlinzner-bcs my thoughts. FYI before you dig in: I'll be under water for the rest of the week, so I can touch any reply to this again next week at the soonest.
The problem as you have solved it is related to what I had in mind, but has more freedom (which is fair based on my description, and it's also something people ask for, so I also think it should be tackled).
Let me explain my understanding of your formulation, then I'll get to the (much smaller and less exciting) motivation ;).
Recap (A.O.): discrete specified levels in each decision variable + choosing a (max) number of experiments to do (so convex)
"Candidates" (A.O)
Quoted from above
This example has 3 decision variable and we assume a fully-quadratic model (i.e. 10 model terms). The candidates are a regular grid covering the design space with 23 points along each dimension (i.e. 12167 candidates --> that's roughly your problem size, right?) and we attempt to choose 100 points from this grid. ...
candidates = np.array(list(product(np.linspace(0, 1, 23), repeat=3)), dtype=np.float64)
Extracting from this, here is how I see your interpretation:
If x_1, x_2, x_3 had been continuous with ranges [0,1], we'd expect a D-optimal DoE to look like (1, 0, 0), (0, 0, 1), etc. (corners and midpoints of the associated box)
You've interpreted candidates to mean the following: the user wants to specify a list, in each dimension, of allowed/interesting/feasible/to-be-measured values, forcing the same length so we get a grid (tell me if there's another reason).
E.g., we could just as easily have had

```python
from itertools import product

import numpy as np

# define the three arrays (discrete levels)
x_1 = np.array([0.23, 0.45, 0.71, 0.93])
x_2 = np.array([0.11, 0.3, 0.6, 0.87])
x_3 = np.array([0.38, 0.4, 0.5, 0.68])

# generate grid of all possible combinations, aka the prescribed space to search
cartesian_product = np.array(list(product(x_1, x_2, x_3)))
```
This example with 4^3 = 64 combinations is not so big.
Method(s) (A.O)
From there on, this is a D-optimal design in this discrete space with the relaxed constraint (experiment count at most n) so that it is convex.
Source Problem: choose optimally the n experiments to (re-)do, from data given (no new points)
The problem as explained has calculated descriptors as decision variables, and each row corresponds to a known product/recipe. To generate data that looks like what is being considered, I'll also pick 3 variables as in your case and use np.random.uniform instead of np.linspace (I'm not going to assume it's well spread out).
"Candidates" : data, maybe correlated(*)
```python
import numpy as np
import pandas as pd

# generate a DataFrame with 3 columns and 20 rows of random values in the range [0, 1]
num_rows = 20
# note: not a product; the result is a (20 x 3) matrix
df = pd.DataFrame(np.random.uniform(0, 1, (num_rows, 3)), columns=["x1", "x2", "x3"])
```
If you want correlated data instead, generate:
- mean: a 3-vector of means in [0, 1]
- covar: a symmetric positive definite (random) covariance matrix
- use `np.random.multivariate_normal` to make correlated data: `correlated_data = np.random.multivariate_normal(mean, covariance, size=num_rows)`
- then pull back into spec (range [0,1]^3) with `norm.cdf` (from scipy.stats) and clip: `uniform_data = np.clip(norm.cdf(correlated_data), 0, 1)`
- assign to a df with the same columns as before (a complete sketch follows)
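Putting those steps together (a runnable sketch; the particular mean and covariance are arbitrary choices of mine):

```python
import numpy as np
import pandas as pd
from scipy.stats import norm

rng = np.random.default_rng(42)
num_rows = 20

# random mean in [0, 1]^3 and a random symmetric positive definite covariance
mean = rng.uniform(0, 1, size=3)
A = rng.normal(size=(3, 3))
covariance = A @ A.T + 0.1 * np.eye(3)

# correlated normal samples, pulled back into [0, 1]^3 via the normal CDF
# (note: the marginals are only truly uniform if the data are standardized first)
correlated_data = rng.multivariate_normal(mean, covariance, size=num_rows)
uniform_data = np.clip(norm.cdf(correlated_data), 0, 1)
df = pd.DataFrame(uniform_data, columns=["x1", "x2", "x3"])
```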
Apply modified A.O. method to select an optimal subset
The only things I changed:
```python
X = df.copy()
candidates = X
n_max = 8
```
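For completeness, an end-to-end sketch of the modified selection (my reconstruction, reusing the fully-quadratic formula from Aaron's example above and a naive largest-weights rounding):

```python
import cvxpy as cp
import numpy as np
from formulaic import Formula

# model matrix for the data rows, same fully-quadratic formula as above
f = Formula("x1*x2 + x1*x3 + x2*x3 + {x1**2} + {x2**2} + {x3**2}")
X = f.get_model_matrix(df).to_numpy()

n_max = 8
x = cp.Variable(X.shape[0])
objective = cp.Maximize(cp.log_det(X.T @ cp.diag(x) @ X))
constraints = [cp.sum(x) == n_max, x >= 0, x <= 1]
cp.Problem(objective, constraints).solve(solver="SCS")

# keep the n_max rows with the largest relaxed weights
selected = df.iloc[np.argsort(-x.value)[:n_max]]
print(selected)
```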
There's a clump that doesn't seem well represented here, to the left/lower left depending on how you look.
Using some anonymized data, 2D problem
OK, fixed my coding errors. The code is similar: a different model (almost fully quadratic, not quite), anonymized data, two decision variables. Left is what the colleague presented as their solution.
Thoughts:
It feels kind of clumpy on this type of problem (and also on the 3D one), and it's not really hugging the edges as much as I'd like.
Ideas on trying to vary the results:
- change objective to be the I-criterion
- check the different solvers
- ...?
Before diving deeper: checking the pyproject.toml, I see cvxpy[CLARABEL] as an optional dependency; not sure what would need doing to make use of it for this purpose.
Hi Rosona,
1. Towards your recap: Absolutely right. There is only one comment I want to make: I had in mind that what the user means by a candidate set is not necessarily a uniform grid, but just any set of points in the design space which we want to allow to be included in the final design. I only used a uniform grid in my example to demonstrate that the approach works, because for the regular grid we roughly know what a D-optimal design looks like. So what you did in your example (replacing the regular grid by some arbitrary candidate set) is perfectly fine and exactly how it was intended to be used.
2. Towards your example:
Your headline Source Problem: For fixed n <= num_rows in data, choose optimally the n experiments to (re-)do implies that at least some of the experiments in the candidate set have already been measured once. Considering already-measured candidates when searching for new ones is indeed not implemented in my code snippet (sorry, I didn't know this was part of your problem to solve). But it will not be difficult to add this feature; we already have this as fixed_experiments in the normal DoEStrategy. I will post an implementation that supports this.
3. Towards the D-optimal design not being optimal for your purposes: I already checked different cvxpy solvers using the regular grid from my example above. There I could not observe any differences between the solutions, if I remember correctly. Choosing another criterion might be an option, but it has to be checked whether it is indeed convex. If it is not convex, one could of course also try to do the optimization in cyipopt instead of cvxpy. This also touches a question I asked myself before: would it be useful if linear combinations of objectives were supported? E.g. in your case, an Objective object which is a weighted sum of a space-filling objective and the normal D-criterion might be helpful. I could implement this sometime soon-ish, if you say that would be helpful. @jduerholt @dlinzner-bcs @rosonaeldred what do you think?
Cheers Aaron :)
Hi Rosona,
1. Towards your recap: Absolutely right. There is only one comment I want to make: I had in mind that what the user means by a candidate set is not necessarily a uniform grid, but just any set of points in the design space which we want to allow to be included in the final design.
But that's not what you did ("...any set of points"). You separated each x direction by using the product of the x_i's to generate your grid. So the user is specifying the levels of discrete variables x_i.
2. Towards your example: Your headline Source Problem: For fixed n <= num_rows in data, choose optimally the n experiments to (re-)do implies that at least some of the experiments in the candidate set have already been measured once. Considering already-measured candidates when searching for new ones is indeed not implemented in my code snippet (sorry, I didn't know this was part of your problem to solve). But it will not be difficult to add this feature; we already have this as fixed_experiments in the normal DoEStrategy. I will post an implementation that supports this.
Sorry, fixed is misleading. n is a number; I am fixing n. I am not fixing a subset of n experiments, just saying that I know how many I want. I will go back and rewrite.
I want to pick n=7 experiments, where n is less than the num_rows of data I have, and I want them to be optimal in the sense that they are the most spread out. I am forcing the chosen experiments to exist amidst the given data.
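If "most spread out" is taken literally (a space-filling rather than a D-optimal objective), a greedy maximin selection over the given rows would be a simple baseline. A sketch, assuming Euclidean distance on the descriptors:

```python
import numpy as np

def greedy_maximin(points, n):
    """Greedily pick n row indices, each time taking the point farthest
    from the points chosen so far (maximin distance heuristic)."""
    # start with the point farthest from the centroid
    chosen = [int(np.argmax(np.linalg.norm(points - points.mean(axis=0), axis=1)))]
    while len(chosen) < n:
        dists = np.linalg.norm(points[:, None, :] - points[chosen], axis=-1)
        min_dist = dists.min(axis=1)     # distance to nearest chosen point
        min_dist[chosen] = -np.inf       # never re-pick a chosen point
        chosen.append(int(np.argmax(min_dist)))
    return chosen

# usage: idx = greedy_maximin(df.to_numpy(), 7); selected = df.iloc[idx]
```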
3. Towards the D-optimal design not being optimal for your purposes: I already checked different cvxpy solvers using the regular grid from my example above. There I could not observe any differences between the solutions, if I remember correctly. Choosing another criterion might be an option, but it has to be checked whether it is indeed convex. If it is not convex, one could of course also try to do the optimization in cyipopt instead of cvxpy. This also touches a question I asked myself before: would it be useful if linear combinations of objectives were supported? E.g. in your case, an Objective object which is a weighted sum of a space-filling objective and the normal D-criterion might be helpful. I could implement this sometime soon-ish, if you say that would be helpful. @jduerholt @dlinzner-bcs @rosonaeldred what do you think?
My brain is a bit too tired to answer this. I pass to @dlinzner-bcs or @R-M-Lee
I am a little late to the party here, but let me offer some thoughts.
First, thanks a lot Aaron for the code snippets and the insights. Really good stuff.
One could try solvers like Gurobi for the non-relaxed problem formulation. Maybe HiGHS or SCIP too; those solvers are strong and well supported. Math question: what class of problem do we get when the objective is D-optimality? E.g., with this part of the relaxed problem:
```python
x = cp.Variable(X.shape[0])  # weights
objective = cp.Maximize(cp.log_det(X.T @ cp.diag(x) @ X))
```
cvxpy separates solvers into these classes: LP | QP | SOCP | SDP | EXP | POW | MIP, and I don't understand which are relevant for the D-optimality problem.
I like the candidate approach a lot because of how practical it is for weird constraints (making a set of feasible candidates can be as simple as grid + filter). I do not, however, see a reason to want a linear combination of space-filling and D-optimality. But maybe my creativity is lacking here.
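One brute-force way to answer the solver-class question empirically (a sketch; it just tries every installed solver on a small instance of the relaxed problem and reports which ones accept it):

```python
import cvxpy as cp

# assumes `objective` and `constraints` from a small instance of the problem above
for name in cp.installed_solvers():
    try:
        cp.Problem(objective, constraints).solve(solver=name)
        print(name, "ok")
    except Exception as exc:
        print(name, "failed:", type(exc).__name__)
```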
Finally, all the pictures above in this issue seem to use "space-filling" as plot titles, but the code here is for D-optimal designs, not space-filling.