
Correcting Acquisition Function Calculation

[Open] badeok0716 opened this issue 2 years ago • 3 comments

Please refer to Issue #14

badeok0716 avatar Jan 04 '24 23:01 badeok0716

To directly show the issue in the original implementation, I printed three values

  • acq_fn(best_seqs[None, :]),
  • acq_fn(best_seqs[..., None]).mean(),
  • acq_fn(best_seqs[..., None])[-1]

at each iteration of the original implementation (optimizers/lambo.py#L334-336) as follows:

                    with torch.no_grad():
                        batch_acq_val = acq_fn(best_seqs[..., None]) 
                        print(batch_acq_val.shape)
                        print("acq_fn(best_seqs[..., None]).mean().item()\n", batch_acq_val.mean().item())
                        print("acq_fn(best_seqs[..., None])[-1].item()\n", batch_acq_val[-1].item())

                        batch_acq_val = acq_fn(best_seqs[None, :])
                        print(batch_acq_val.shape)
                        print("acq_fn(best_seqs[None, :]).mean().item()\n", batch_acq_val.mean().item())
                        print()
                    curr_score = -1.0 * batch_acq_val

I ran the command `python scripts/black_box_opt.py optimizer=lambo optimizer.encoder_obj=mlm task=regex tokenizer=protein surrogate=multi_task_exact_gp acquisition=nehvi acquisition.num_samples=100` and got the following results:


---- loading checkpoint from epoch 256 ----
|    |    epoch |   train_loss |   val_nll |   val_rmse |   val_s_rho |   val_ece |   val_occ_diff |   val_post_var |   noise |   lengthscale |   test_nll |   test_rmse |   test_s_rho |   test_ece |   test_occ_diff |   test_post_var |   val_perplexity |   test_perplexity |   best_score |   best_epoch |   best_loss |   best_loss_epoch |
|---:|---------:|-------------:|----------:|-----------:|------------:|----------:|---------------:|---------------:|--------:|--------------:|-----------:|------------:|-------------:|-----------:|----------------:|----------------:|-----------------:|------------------:|-------------:|-------------:|------------:|------------------:|
|  0 | 256.0000 |       1.9706 |   -0.1337 |     0.2325 |      0.7605 |    0.2407 |        -0.2407 |         0.0118 |  0.0612 |        0.7001 |    -0.3408 |      0.1399 |       0.6226 |     0.3415 |         -0.3415 |          0.0052 |          33.0809 |           42.0368 |      -0.1337 |     256.0000 |      1.9706 |          256.0000 |

---- optimizing candidates ----
torch.Size([16])
acq_fn(best_seqs[..., None]).mean().item()
 0.03003895409041333
acq_fn(best_seqs[..., None])[-1].item()
 0.002817928563987497
torch.Size([1])
acq_fn(best_seqs[None, :]).mean().item()
 0.003015315650346108

torch.Size([16])
acq_fn(best_seqs[..., None]).mean().item()
 0.030456435381897233
acq_fn(best_seqs[..., None])[-1].item()
 0.002736823933254607
torch.Size([1])
acq_fn(best_seqs[None, :]).mean().item()
 0.003122292317222965

torch.Size([16])
acq_fn(best_seqs[..., None]).mean().item()
 0.02859126123072568
acq_fn(best_seqs[..., None])[-1].item()
 0.002836023691358241
torch.Size([1])
acq_fn(best_seqs[None, :]).mean().item()
 0.00325981464099024

...

In other words, acq_fn(best_seqs[None, :]).mean().item() returns the 1-NEHVI of best_seqs[-1] alone (i.e., acq_fn(best_seqs[..., None])[-1]), not the average of the 1-NEHVI values over all of best_seqs (i.e., acq_fn(best_seqs[..., None]).mean().item()).

(I used torch.dtype=torch.double on a machine with an RTX 3090 GPU.)
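The two indexing patterns can be illustrated with a toy stand-in for the acquisition function. Note that toy_acq_fn below is hypothetical (the real acq_fn wraps a BoTorch NEHVI object); it only mimics the shape convention that an input of shape [num_batches, q] yields one joint value per batch:

```python
import numpy as np

def toy_acq_fn(X):
    # Hypothetical stand-in: interprets X as [num_batches, q] and
    # returns one "joint" value per batch (a sum here, as a placeholder
    # for a joint acquisition value over the q candidates in a batch).
    return X.sum(axis=-1)

seqs = np.arange(16.0)                 # 16 candidate "sequences"

joint = toy_acq_fn(seqs[None, :])      # input [1, 16]: 1 batch, q=16
per_seq = toy_acq_fn(seqs[..., None])  # input [16, 1]: 16 batches, q=1

print(joint.shape)    # (1,)  one joint q-value over all candidates
print(per_seq.shape)  # (16,) sixteen independent single-candidate values
```

So best_seqs[None, :] asks for a single joint value over all 16 sequences, while best_seqs[..., None] asks for 16 independent single-sequence values; the symptom reported above is that the former silently behaved like a q=1 evaluation of one sequence.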

badeok0716 avatar Jan 05 '24 05:01 badeok0716

I re-read the original code, this time considering the case of EHVI instead of NEHVI.

From the EHVI implementation, I found that

  • acq_fn(seqs[None, :]) computes q-EHVI,
  • acq_fn(seqs[..., None]) computes q separate 1-EHVI values, one for each seqs[i], i = 1, ..., q,

where seqs is a numpy array of strings with shape [q].

The recent commit is a different version, but it is aligned with this EHVI convention.

badeok0716 avatar Jan 06 '24 03:01 badeok0716

Here, I summarize the three issues I fixed in this PR. Please refer to the changes in https://github.com/samuelstanton/lambo/compare/main...badeok0716:lambo-fix:main.

  1. Error in q-NEHVI calculation. Setting q=X.shape[-2] in lambo/acquisitions/monte_carlo.py#L69-L89 leads to an incorrect q-NEHVI value. I fixed it in the fork's lambo/acquisitions/monte_carlo.py#L69-L90.

  2. Error in apply_mutation. The previous code for the apply_mutation function in lambo/utils.py applies mutations incorrectly. Please refer to https://github.com/samuelstanton/lambo/commit/265e918ad576e1fb190fdb6da1e329c2cc18a4b6. (Since tokenizer.decode() generates neither a bos nor an eos token in this repo's implementation, tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1] in lambo/utils.py deletes the tokens at both ends, e.g., ACTGCCG -> CTGCC.)

  3. Discrepancy in the LaMBO algorithm between the paper and GitHub. The LaMBO algorithm implemented in the recent commits of this repo differs from the one described in the paper. I reverted the algorithm to the older version in this commit, which matches the paper's algorithm.
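For issue 1, here is a shape-level sketch of the symptom. This assumes the string candidates arrive as an object array of shape [num_batches, q] (so q lives in the last dimension); the actual acquisition math is simplified away:

```python
import numpy as np

# Candidate layouts produced by the two call patterns discussed above:
X_joint = np.empty((1, 16), dtype=object)   # seqs[None, :]   -> 1 batch of q=16
X_single = np.empty((16, 1), dtype=object)  # seqs[..., None] -> 16 batches of q=1

# Buggy extraction: q taken from the second-to-last dimension,
# which actually holds the batch count for string inputs.
q_buggy = X_joint.shape[-2]   # 1  -> the "joint" call degenerates to q=1
# Corrected extraction: q is the last dimension.
q_fixed = X_joint.shape[-1]   # 16 -> a true joint evaluation over all candidates

# Note the reversal: the buggy indexing reads q=16 from the
# per-sequence layout and q=1 from the joint layout.
q_buggy_single = X_single.shape[-2]  # 16
```

With q taken as 1, the supposedly joint evaluation reduces to a single-candidate acquisition, which is consistent with acq_fn(best_seqs[None, :]) tracking a 1-NEHVI value in the logs above.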
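For issue 2, a minimal mock reproduces the off-by-one. MockTokenizer is hypothetical; it only imitates the relevant property that decode() adds no bos/eos tokens:

```python
class MockTokenizer:
    # Hypothetical stand-in for lambo's tokenizer: encodes a sequence
    # character by character and decodes by joining with spaces,
    # WITHOUT adding bos/eos tokens, mirroring the repo's behavior.
    def encode(self, seq):
        return list(seq)

    def decode(self, tokens):
        return " ".join(tokens)

tokenizer = MockTokenizer()
base_seq = "ACTGCCG"

# Old code: strips the first and last entries, assuming they are bos/eos.
buggy_tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")[1:-1]
print("".join(buggy_tokens))  # CTGCC  (real tokens at both ends were lost)

# Without the [1:-1] slice, the sequence round-trips intact.
fixed_tokens = tokenizer.decode(tokenizer.encode(base_seq)).split(" ")
print("".join(fixed_tokens))  # ACTGCCG
```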

The following plot compares the performance of the fixed version against the released results in https://github.com/samuelstanton/lambo/blob/main/notebooks/plot_hypervolume.ipynb.

(Screenshot: hypervolume comparison plot, 2024-01-11)

After correcting the two issues (the q-NEHVI calculation and apply_mutation), the performance of the GA-based methods improves substantially on both tasks.

Note that I ran the following commands

  • NSGAII: `python scripts/black_box_opt.py optimizer=mf_genetic optimizer/algorithm=nsga2 task=[task] tokenizer=[tokenizer] trial_id=[trial_id]`
  • MBGA: `python scripts/black_box_opt.py optimizer=mb_genetic optimizer/algorithm=soga optimizer.encoder_obj=mll task=[task] tokenizer=[tokenizer] surrogate=multi_task_exact_gp acquisition=nehvi trial_id=[trial_id]`
  • LaMBO: `python scripts/black_box_opt.py optimizer=lambo optimizer.encoder_obj=mlm task=[task] tokenizer=[tokenizer] surrogate=multi_task_exact_gp acquisition=nehvi trial_id=[trial_id]`

with (task, tokenizer) = (regex, protein) or (chem, selfies), for trial_id = 1, ..., 10.

badeok0716 avatar Jan 11 '24 06:01 badeok0716

Thanks again for your work tracking this down. Merged into master and made a note in the README.

samuelstanton avatar Apr 20 '24 05:04 samuelstanton