
[RFC] Refactor Input Transforms

Open saitcakmak opened this issue 3 years ago • 3 comments

Summary: Currently, we apply the input transforms in train mode at the forward call, and in eval mode at the posterior call. We also use a transform_train_inputs call at the eval/train calls to make sure that at eval time the train_inputs are stored as transformed (since they do not pass through posterior). This design supports ExactGP models, and supports specifying where to apply which input transform via the flags (so that one-to-many transforms are only applied to test inputs). However, it does not work well with Approximate GP models, since this setup does not transform the inducing points at eval time.
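For illustration, here is a rough sketch of the current flow described above. This is simplified placeholder code, not the actual BoTorch implementation; the class name and the stand-in return values are hypothetical.

```python
# Illustrative sketch of the current behavior (not the actual BoTorch code):
# input transforms are applied in `forward` during training and in `posterior`
# during eval, with train_inputs pre-transformed on the eval() call.
class CurrentStyleModel:
    def __init__(self, train_X, input_transform):
        self.training = True
        self.train_inputs = train_X
        self.input_transform = input_transform  # a callable one-to-one transform

    def transform_inputs(self, X):
        return self.input_transform(X)

    def eval(self):
        # transform_train_inputs: store train_inputs as transformed, since they
        # will not pass through `posterior` in eval mode. Note that the inducing
        # points of an approximate GP are *not* handled here, which is the gap
        # this RFC is addressing.
        self.train_inputs = self.transform_inputs(self.train_inputs)
        self.training = False
        return self

    def forward(self, X):
        if self.training:
            # Train mode: inputs are transformed at the forward call.
            X = self.transform_inputs(X)
        return X  # stand-in for computing the prior

    def posterior(self, X):
        if not self.training:
            # Eval mode: test inputs are transformed at the posterior call instead.
            X = self.transform_inputs(X)
        return self.forward(X)  # stand-in for computing the posterior
```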

This refactor splits out one-to-many transforms into a separate InputAugmentationTransform, which lets us revert to simply applying transform_inputs in the forward pass at all times. Since one-to-many transforms (now InputAugmentationTransform) still need to be applied in posterior, we introduce an augment_inputs method. Inspired by the public/private APIs of Ax, and to minimize the transform-related knowledge expected from developers, this introduces a Model.forward call that applies transform_inputs and calls self._forward. <AnyGivenModel>._forward is then the usual forward call that computes the prior, except that it no longer has to worry about transforms. Similarly, Model.posterior becomes a simple wrapper around Model._posterior that applies the augment_inputs call and the posterior_transform. <AnyGivenModel>._posterior is the usual posterior call that no longer has to worry about input or posterior transforms (it still has to deal with the outcome transform in the current implementation, though we could fix this by bringing back the fantasize flag).
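As a minimal sketch of the proposed split, using plain-Python placeholders rather than the actual BoTorch/GPyTorch classes (the concrete signatures and attribute names are assumptions based on the description above):

```python
from abc import ABC, abstractmethod


class Model(ABC):
    """Base class that owns all transform handling (hypothetical sketch)."""

    input_transform = None               # one-to-one transforms
    input_augmentation_transform = None  # one-to-many transforms

    def transform_inputs(self, X):
        # Applied in forward, in both train and eval mode.
        return X if self.input_transform is None else self.input_transform(X)

    def augment_inputs(self, X):
        # One-to-many transforms are only applied to test inputs, in posterior.
        if self.input_augmentation_transform is None:
            return X
        return self.input_augmentation_transform(X)

    def forward(self, X):
        # Public forward: transform inputs, then delegate to the model-specific _forward.
        return self._forward(self.transform_inputs(X))

    def posterior(self, X, posterior_transform=None, **kwargs):
        # Public posterior: augment inputs, delegate, then apply the posterior transform.
        posterior = self._posterior(self.augment_inputs(X), **kwargs)
        if posterior_transform is not None:
            posterior = posterior_transform(posterior)
        return posterior

    @abstractmethod
    def _forward(self, X):
        """Compute the prior; no transform handling needed here."""

    @abstractmethod
    def _posterior(self, X, **kwargs):
        """Compute the posterior; no input/posterior transform handling needed here."""
```

A concrete model would then implement only `_forward` and `_posterior` and never deal with transforms directly; since transform_inputs is applied in forward at all times, train and test inputs go through it consistently.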

This diff presents a minimal implementation around the SingleTaskGP model.

Differential Revision: D35129407

saitcakmak avatar Apr 14 '22 15:04 saitcakmak

This pull request was exported from Phabricator. Differential Revision: D35129407

facebook-github-bot avatar Apr 14 '22 15:04 facebook-github-bot

cc @wjmaddox. For context, this came out of a discussion around the input transforms and variational strategy / inducing points. The current "apply only in posterior in eval mode" approach skips over the inducing points when evaluating the posterior (we pre-transform the train_inputs on the model.eval() call but not the inducing points).

saitcakmak avatar Apr 14 '22 16:04 saitcakmak

This looks great! Yeah, I really struggled with input transforms for variational GPs (I don't think the current version in BoTorch handles them particularly well) and had to place them in the forward call for my own research code. This also seems like a sensible structure for separating one-to-one transforms from one-to-many transforms.

wjmaddox avatar Apr 14 '22 16:04 wjmaddox

Closed in favor of #1372

saitcakmak avatar Oct 05 '22 23:10 saitcakmak