
Access ModelBridge Transforms from Acquisition Function

akern40 opened this issue 3 years ago • 6 comments

I am currently working on implementing an acquisition function to find contour levels; to do so, the acquisition function takes in a set of "integration points" from within the search space and a set of "contour levels" from within the outcome space. This works fine in pure BoTorch, since transformations of the search and outcome spaces are handled directly in the GP model (i.e., in the Surrogate in Ax). However, as best I can tell, Ax handles its input and outcome transformations in the ModelBridge, not in the Surrogate. As a result, the transformations (which change with each iteration) are known neither to the acquisition function nor to the surrogate model when computing acquisition function values, so there is no way to transform the "integration points" and "contour levels" to match the transformations applied during model fitting. Is there any way to easily access these transformations inside an AcquisitionFunction subclass? Any help would be greatly appreciated.
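For concreteness, here is a minimal sketch of the pure-BoTorch setup described above; the class name `ContourLevelAcquisition` and the scoring rule are made up for illustration, and a real implementation would of course use `X`:

```python
import torch
from botorch.acquisition import AcquisitionFunction
from botorch.models.model import Model
from botorch.utils.transforms import t_batch_mode_transform


class ContourLevelAcquisition(AcquisitionFunction):
    """Illustrative acquisition function scoring candidates against a set of
    integration points and target contour levels (toy example)."""

    def __init__(
        self,
        model: Model,
        integration_points: torch.Tensor,  # n x d, in the raw search space
        contour_levels: torch.Tensor,  # target levels in the raw outcome space
    ) -> None:
        super().__init__(model=model)
        self.register_buffer("integration_points", integration_points)
        self.register_buffer("contour_levels", contour_levels)

    @t_batch_mode_transform()
    def forward(self, X: torch.Tensor) -> torch.Tensor:
        # The model's own input/outcome transforms are applied inside posterior(),
        # so raw-space integration points and contour levels can be used directly.
        posterior = self.model.posterior(self.integration_points)
        mean = posterior.mean  # n x m
        # Toy score (ignores X): closeness of the posterior mean to the levels.
        score = -(mean - self.contour_levels).abs().mean()
        return score.expand(X.shape[0])
```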

akern40 avatar May 12 '22 16:05 akern40

Hi, unfortunately I think this is quite challenging to do in the current setup. The original design of Ax was, as you observed, to perform all data and parameter transformations in the ModelBridge layer and not have the models know about these transformations. This works great if you don't need to access them, and it allows the transforms to be implemented once and work for a variety of models (e.g. GPs, trees, what have you). But it does make it less straightforward to do things like learnable data transforms that you want to backprop through, or the kind of thing you're trying to do here.

I guess short of a broad (and unlikely in the short term) departure from the current paradigm, I can see two potential options here:

  1. Extend the model PI to allow (optionally) passing some information about the transforms down to the models.
  2. Turn off the relevant data transformations in the ModelBridge layer (by defining the transforms so they don't include the defaults), and then handle the transformation yourself on the BoTorch model side (a rough sketch follows below).
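For option 2, a rough sketch of what the Ax side could look like with the Models registry is below, assuming a Modular BoTorch setup and an existing `experiment`; the exact transforms you keep will depend on your search space:

```python
from ax.modelbridge.registry import Models, Cont_X_trans, Y_trans
from ax.modelbridge.transforms.standardize_y import StandardizeY
from ax.modelbridge.transforms.unit_x import UnitX

# Keep the default transforms except the ones that rescale inputs/outcomes,
# so the BoTorch model (and hence the acquisition function) sees raw values.
transforms = [t for t in Cont_X_trans + Y_trans if t not in (UnitX, StandardizeY)]

model_bridge = Models.BOTORCH_MODULAR(
    experiment=experiment,          # assumed: an existing Ax Experiment
    data=experiment.fetch_data(),
    transforms=transforms,
)
```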

Balandat avatar May 14 '22 00:05 Balandat

@bletham, you might find this problem interesting...

eytan avatar May 14 '22 01:05 eytan

Thanks for the responses, I figured that might be the case.

  1. Extend the model PI to allow (optionally) passing some information about the transforms down to the models.

What's the model PI? I'd be happy to try to work on a PR that allows optional keyword arguments to be passed down to the acquisition function's `__init__` or `forward` methods.

  2. Turn off the relevant data transformations in the ModelBridge layer (by defining the transforms so they don't include the defaults), and then handle the transformation yourself on the BoTorch model side.

I had considered this as well, and for the simple search space I'm using this might be a feasible option.
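For reference, a minimal sketch of the BoTorch-side half of option 2, with the (re)scaling attached to the model itself rather than the ModelBridge (the 2-d training data here is made up):

```python
import torch
from botorch.models import SingleTaskGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize

# Hypothetical raw-space training data for a 2-d search space.
train_X = torch.rand(20, 2) * torch.tensor([10.0, 99.0]) + torch.tensor([0.0, 1.0])
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)

# With normalization/standardization living inside the model, posterior() accepts
# raw-space points, so integration points and contour levels need no manual
# transformation in the acquisition function.
model = SingleTaskGP(
    train_X,
    train_Y,
    input_transform=Normalize(d=2),
    outcome_transform=Standardize(m=1),
)
```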

akern40 avatar May 14 '22 02:05 akern40

What's the model PI?

A typo :) Should be API. You probably want to wait until #967 goes in / start on top of that, unless you really like resolving merge conflicts.

Balandat avatar May 14 '22 04:05 Balandat

Ah yeah, I like challenges but I'm not a masochist. I'll wait until that PR is complete, and I can take the time to consider some design options to propose for this issue.

akern40 avatar May 14 '22 13:05 akern40

@akern40, the PR is now complete if you want to take this on! Let us know if you have any questions or need help. I'm going to mark this as "needs repro or more info" for now, as we'll be waiting to hear from you on how this goes!

lena-kashtelyan avatar Sep 13 '22 17:09 lena-kashtelyan

Closing as inactive. Please reopen if following up.

lena-kashtelyan avatar Dec 06 '22 18:12 lena-kashtelyan