
[Feature Request] Add POI to Multi-Objective Analytic Acquisition Functions

Open · Ruan-Yixiang opened this issue 2 years ago · 12 comments

🚀 Feature Request

Motivation

I want to employ the multi-objective analytic acquisition function probability of improvement (PoI), described in Yang2019, in my research work: Yang, K., Emmerich, M., Deutz, A., et al. Efficient computation of expected hypervolume improvement using box decomposition algorithms. J Glob Optim 75, 3–34 (2019).

Pitch

BoTorch's multi-objective analytic acquisition functions module contains the expected hypervolume improvement (EHVI) developed in Yang2019. Could the PoI mentioned in Yang2019 be added to this module as well? Thanks!
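For concreteness, my reading of the PoI in Yang2019 is the probability that the new point lands in the non-dominated region $A$. With a box decomposition $A = \bigcup_{k=1}^{K} [l^{(k)}, u^{(k)}]$ and independent Gaussian objectives (as assumed in the paper):

$$\mathrm{PoI}(x) = \sum_{k=1}^{K} \prod_{j=1}^{m} \left[ \Phi\!\left(\frac{u_j^{(k)} - \mu_j(x)}{\sigma_j(x)}\right) - \Phi\!\left(\frac{l_j^{(k)} - \mu_j(x)}{\sigma_j(x)}\right) \right],$$

where $\Phi$ is the standard normal CDF and $\mu_j, \sigma_j$ are the posterior mean and standard deviation of objective $j$ at $x$.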

Ruan-Yixiang avatar Jan 17 '23 13:01 Ruan-Yixiang

@sdaulton did you ever look into implementing MOO PoI? For non-batch (q=1) versions this should not be too bad (potentially using some of @j-wilson's MVN CDF work). MC sampling will be straightforward but we'd have to worry about the gradients.
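For the MC version, a rough q=1 sketch (a hypothetical helper, assuming maximization; not existing BoTorch code) could look like:

import torch

def mc_moo_poi(samples, pareto_Y):
    # samples: (n_mc, m) posterior samples of the m objectives at a point x
    # pareto_Y: (p, m) current Pareto front (maximization)
    # a sample is weakly dominated if some Pareto point is >= in every objective
    dominated = (pareto_Y.unsqueeze(0) >= samples.unsqueeze(1)).all(dim=-1).any(dim=-1)
    # the indicator has zero gradients everywhere, which is exactly the concern above
    return (~dominated).float().mean()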

Balandat avatar Jan 17 '23 15:01 Balandat

> @sdaulton did you ever look into implementing MOO PoI? For non-batch (q=1) versions this should not be too bad (potentially using some of @j-wilson's MVN CDF work). MC sampling will be straightforward but we'd have to worry about the gradients.

Does "MOO PoI" mean Multi-Objective Optimization PoI? If it does, this is what I want. How can I implement it in botorch (just q=1)? I don't seem to find this acquisition function.

Ruan-Yixiang avatar Jan 17 '23 16:01 Ruan-Yixiang

@Ruan-Yixiang please see the discussion in https://github.com/pytorch/botorch/discussions/1423. It includes an implementation of MC PoI for the multi-objective case.

saitcakmak avatar Jan 17 '23 17:01 saitcakmak

> @Ruan-Yixiang please see the discussion in #1423. It includes an implementation of MC PoI for the multi-objective case.

Thanks for your reply. I have read the discussion. The author there wants to use the product of the per-objective PI values as the acquisition function to generate the next candidate, but I want to compute the probability of improvement of the hypervolume as a stopping criterion for MOO (as in this paper: https://arxiv.org/abs/1511.07827). The discussion in #1423 may not solve my problem.

Ruan-Yixiang avatar Jan 18 '23 04:01 Ruan-Yixiang

Oh, I see. In that case, you probably want to do something like replacing the mean with a differentiable approximation of the probability (e.g., a sum of sigmoids as in the PoI code) in https://github.com/pytorch/botorch/blob/main/botorch/acquisition/multi_objective/monte_carlo.py#L353

If I am not mistaken, areas_per_segment.sum(dim=-1) gives you the HVI for each MC sample. Something like this should get you the probability of HVI:

hvi_per_sample = areas_per_segment.sum(dim=-1)  # HVI for each MC sample
p_of_hvi = torch.sigmoid(hvi_per_sample / self.eta).mean(dim=0)  # smoothed estimate of P(HVI > 0)
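Here eta acts as a temperature: the smaller it is, the closer the sigmoid gets to the exact indicator, which you could use directly if you did not need gradients:

p_of_hvi_exact = (hvi_per_sample > 0).float().mean(dim=0)  # exact, but has zero gradients everywhere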

saitcakmak avatar Jan 18 '23 04:01 saitcakmak

@Ruan-Yixiang I recommend you try using MVNXPB, which is a state-of-the-art approximator for Gaussian probabilities of the form P(a < f(X) < b). In the case of qPI, this would look something like:

from botorch.utils.probability.mvnxpb import MVNXPB

# Compute the posterior at the candidate set X
posterior = self.model.posterior(
    X=X, posterior_transform=self.posterior_transform
)

# Define lower and upper bounds: padding best_f - posterior.mean (shape ... x q x 1)
# on the left turns it into bounds of shape ... x q x 2 with lower bound -inf
bounds = torch.nn.functional.pad(
    input=best_f - posterior.mean, pad=(1, 0), value=-float("inf")
)

# Evaluate log P(f(X) < best_f)
log_prob = MVNXPB(posterior.covariance_matrix, bounds=bounds).solve()

# Return 1 - P(f(X) < best_f)
prob_improvement = log_prob.expm1().neg()

j-wilson avatar Jan 18 '23 16:01 j-wilson

> Oh, I see. In that case, you probably want to do something like replacing the mean with a differentiable approximation of the probability (e.g., a sum of sigmoids as in the PoI code) in https://github.com/pytorch/botorch/blob/main/botorch/acquisition/multi_objective/monte_carlo.py#L353
>
> If I am not mistaken, areas_per_segment.sum(dim=-1) gives you the HVI for each MC sample. Something like this should get you the probability of HVI:
>
> hvi_per_sample = areas_per_segment.sum(dim=-1)  # HVI for each MC sample
> p_of_hvi = torch.sigmoid(hvi_per_sample / self.eta).mean(dim=0)  # smoothed estimate of P(HVI > 0)

Thank you for your reply. I tried your idea. In MOO, HVI is always ≥ 0, so

p_of_hvi = torch.sigmoid(hvi_per_sample / self.eta).mean(dim=0)

is always ≥ 0.5 (> 0.75 in most cases), which means this PoI cannot serve as a stopping criterion for MOO (stop when PoI < some tiny number). So I decided to use the proportion of MC samples whose HVI exceeds a tiny threshold as the probability of HVI. The code looks like this:

hvi_per_sample = self._compute_qehvi(samples=samples)  # HVI for each MC sample
n_sample = hvi_per_sample.size()[0]  # number of MC samples
p = 0
for i in hvi_per_sample:
    if i > 5e-3:
        p += 1
poi = torch.tensor(p / n_sample)
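The loop can also be written as a single vectorized line (an equivalent variant):

poi = (hvi_per_sample > 5e-3).float().mean()  # fraction of MC samples with HVI above the threshold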

The tiny number (5e-3 in the code) plays a role similar to the trade-off parameter in single-objective PoI.

(image attachment) Do you think this is reasonable? It seems to be effective so far.

Ruan-Yixiang avatar Jan 26 '23 16:01 Ruan-Yixiang

> @Ruan-Yixiang I recommend you try using MVNXPB, which is a state-of-the-art approximator for Gaussian probabilities of the form P(a < f(X) < b). In the case of qPI, this would look something like:
>
> from botorch.utils.probability.mvnxpb import MVNXPB
>
> # Compute the posterior at the candidate set X
> posterior = self.model.posterior(
>     X=X, posterior_transform=self.posterior_transform
> )
>
> # Define lower and upper bounds
> bounds = torch.nn.functional.pad(
>     input=best_f - posterior.mean, pad=(1, 0), value=-float("inf")
> )
>
> # Evaluate log P(f(X) < best_f)
> log_prob = MVNXPB(posterior.covariance_matrix, bounds=bounds).solve()
>
> # Return 1 - P(f(X) < best_f)
> prob_improvement = log_prob.expm1().neg()

Thank you for the reply. I will try this soon.

Ruan-Yixiang avatar Jan 26 '23 16:01 Ruan-Yixiang

@Ruan-Yixiang were you able to try this? Would be great if you could share your findings here.

Balandat avatar Feb 11 '23 20:02 Balandat

I use the proportion of MC samples whose HVI exceeds a tiny number (xi) as the probability of HVI. The code looks like this:

# _compute_qehvi is modified to return areas_per_segment.sum(dim=-1),
# i.e., the HVI for each MC sample, instead of averaging over samples

@concatenate_pending_points
@t_batch_mode_transform()
def forward(self, X: Tensor) -> Tensor:
    posterior = self.model.posterior(X)
    samples = self.sampler(posterior)
    hvi_per_sample = self._compute_qehvi(samples=samples)  # HVI for each MC sample
    n_sample = hvi_per_sample.size()[0]  # number of MC samples
    hvi_per_sample = (hvi_per_sample - self.xi).clamp_min(0)  # zero out HVIs at or below the threshold xi
    sample_impr = torch.nonzero(hvi_per_sample).size()[0]  # count of samples with HVI > xi
    # the proportion of MC samples whose HVI exceeds the tiny number xi
    return torch.tensor(sample_impr / n_sample)  # p_of_hvi
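Since this counts samples, the returned value has no useful gradients, but that is fine for my use case: I only evaluate it as a stopping criterion, not as an acquisition function to be optimized with gradients.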

Here, xi is set to one thousandth of the current hypervolume:

self.xi = partitioning.compute_hypervolume() / 1000

Judging by the results so far, this approach works well.

Ruan-Yixiang avatar Feb 13 '23 12:02 Ruan-Yixiang

Late to the game here, but implementing analytic PoI under the common assumption of independent objectives would also be easy to do.
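Roughly, given a box decomposition of the non-dominated region, that could look like this sketch (a hypothetical helper, not an official implementation):

import torch
from torch.distributions import Normal

def analytic_moo_poi(mean, sigma, cell_lower, cell_upper):
    # mean, sigma: (m,) posterior mean / std of the m objectives at a point x
    # cell_lower, cell_upper: (K, m) bounds of the K disjoint boxes that cover
    # the non-dominated region (e.g., from a box decomposition of the Pareto front)
    normal = Normal(mean, sigma)
    # P(l_j < f_j(x) < u_j) per box and objective; independence lets us take
    # the product over the m objectives
    box_probs = (normal.cdf(cell_upper) - normal.cdf(cell_lower)).prod(dim=-1)
    # the boxes are disjoint, so their probabilities sum to P(f(x) in the region)
    return box_probs.sum()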

sdaulton avatar Feb 13 '23 15:02 sdaulton

FYI, here is a (non-MOO) implementation of qPI: #1684

Balandat avatar Feb 16 '23 18:02 Balandat

Closing this as inactive. If there is still interest in implementing this, we'd welcome a PR into botorch_community.

saitcakmak avatar Jul 24 '24 18:07 saitcakmak