
[RFC] Move input transforms to GPyTorch

Open · saitcakmak opened this issue 3 years ago · 3 comments

Summary: This diff presents a minimal implementation of input transforms in GPyTorch. See cornellius-gp/gpytorch#2114 for the GPyTorch side of these changes.

What this does:

  • Moves transform_inputs from the BoTorch Model to the GPyTorch GP class, with some modifications to explicitly identify whether the given inputs are train or test inputs.
  • Modifies the InputTransform.forward call to use an is_training_input argument instead of the self.training check when deciding whether to apply the transforms that have transform_on_train=True.
  • Removes the preprocess_transform method, since it is no longer needed.
  • For ExactGP models, it transforms both train and test inputs in __call__. For train_inputs, it always uses is_training_input=True. For generic inputs, it uses is_training_input=self.training, which signals that these are training inputs when the model is in train mode and test inputs when the model is in eval mode.
  • For ApproximateGP models, it applies the transform to inputs in __call__ using is_training_input=self.training. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transforms inducing_points, which fixes the previous bug where inducing_points were transformed in train mode but not in eval mode. The user is expected to define inducing points in the appropriate space (usually the normalized space / unit cube).
  • For BoTorch SingleTaskVariationalGP, it moves the input_transform attribute down to _SingleTaskVariationalGP, which is the actual ApproximateGP instance. This makes the transform accessible from GPyTorch.
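The dispatch rules in the list above can be illustrated with a minimal sketch. It is not the code from this diff or from cornellius-gp/gpytorch#2114: AffineInputTransform, ToyExactGP and ToyApproximateGP are hypothetical stand-ins that operate on plain lists of floats instead of torch tensors, and they model only the is_training_input logic (train_inputs always count as training inputs, generic inputs follow self.training, inducing points are never transformed).

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class AffineInputTransform:
    """Hypothetical stand-in for an InputTransform: x -> (x - offset) / scale."""
    offset: float
    scale: float
    transform_on_train: bool = True
    transform_on_eval: bool = True

    def forward(self, X: List[float], is_training_input: bool) -> List[float]:
        # The caller states explicitly whether X holds training inputs,
        # instead of the transform inspecting its own self.training flag.
        apply = self.transform_on_train if is_training_input else self.transform_on_eval
        if not apply:
            return list(X)
        return [(x - self.offset) / self.scale for x in X]


@dataclass
class ToyExactGP:
    """Mimics only the input-transform dispatch described for ExactGP.__call__."""
    train_inputs: List[float]
    input_transform: Optional[AffineInputTransform] = None
    training: bool = True

    def transform_inputs(self, X: List[float], is_training_input: bool) -> List[float]:
        if self.input_transform is None:
            return list(X)
        return self.input_transform.forward(X, is_training_input=is_training_input)

    def __call__(self, X: List[float]):
        # train_inputs are always treated as training inputs ...
        train_X = self.transform_inputs(self.train_inputs, is_training_input=True)
        # ... while generic inputs are training inputs in train mode, test inputs in eval mode.
        X = self.transform_inputs(X, is_training_input=self.training)
        return train_X, X


@dataclass
class ToyApproximateGP:
    """Mimics only the input-transform dispatch described for ApproximateGP.__call__."""
    inducing_points: List[float]
    input_transform: Optional[AffineInputTransform] = None
    training: bool = True

    def __call__(self, X: List[float]):
        if self.input_transform is not None:
            X = self.input_transform.forward(X, is_training_input=self.training)
        # Inducing points are never transformed; they must already live in the
        # transformed space (e.g. the unit cube).
        return self.inducing_points, X
```

With transform_on_train=False, the toy ExactGP leaves train_inputs untouched in both modes and transforms generic inputs only in eval mode, which is the behavior the is_training_input argument is meant to make explicit.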

What this doesn't do:

  • It doesn't do anything about DeterministicModels. Those will still need to deal with their own transforms, which is not implemented here. If we make Model inherit from GP, we can keep the existing setup with very minimal changes.
  • It does not clean up the call sites for self.transform_inputs. That method is just made into a no-op, and the clean-up is left for later.
  • It does not upstream the abstract InputTransform classes to GPyTorch. That'll be done if we decide to go forward with this design.
  • It does not touch PairwiseGP. PairwiseGP makes some non-standard uses of input transforms, so it needs an audit to make sure things still work fine.
  • I didn't look into ApproximateGP.fantasize. This may need some changes similar to ExactGP.get_fantasy_model.
  • It does not support PyroGP and DeepGP.

Differential Revision: D39147547

saitcakmak, Aug 31 '22 21:08

This pull request was exported from Phabricator. Differential Revision: D39147547

facebook-github-bot, Aug 31 '22 21:08

I like this design!

> It doesn't do anything about DeterministicModels. Those will still need to deal with their own transforms, which is not implemented here. If we make Model inherit from GP, we can keep the existing setup with very minimal changes.

I think we probably want to have something like a gpytorch BaseModel that implements the transform handling in a general sense, and then have GP inherit from that (so that deterministic models don't suddenly become GPs).
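A minimal sketch of the suggested hierarchy, assuming a hypothetical BaseModel that owns the input transform and a transform_inputs method; GP and DeterministicModel here are toy classes (not the GPyTorch or BoTorch ones) and work on lists of floats rather than tensors. The point is only that both model families share the transform logic without a deterministic model being a GP.

```python
from typing import Callable, List, Optional

# Hypothetical transform signature: (inputs, is_training_input) -> transformed inputs.
Transform = Callable[[List[float], bool], List[float]]


class BaseModel:
    """Hypothetical shared base class that owns the input-transform handling."""

    def __init__(self, input_transform: Optional[Transform] = None) -> None:
        self.input_transform = input_transform
        self.training = True

    def transform_inputs(self, X: List[float], is_training_input: bool) -> List[float]:
        if self.input_transform is None:
            return list(X)
        return self.input_transform(X, is_training_input)


class GP(BaseModel):
    """GPs inherit the transform handling from BaseModel ..."""


class DeterministicModel(BaseModel):
    """... and so can non-GP models, without becoming GPs."""

    def __init__(self, f: Callable[[float], float], input_transform: Optional[Transform] = None) -> None:
        super().__init__(input_transform)
        self.f = f

    def __call__(self, X: List[float]) -> List[float]:
        X = self.transform_inputs(X, is_training_input=self.training)
        return [self.f(x) for x in X]
```

For example, DeterministicModel(lambda x: x + 1.0, input_transform=lambda X, is_train: [2.0 * x for x in X]) maps [1.0, 2.0] to [3.0, 5.0], and isinstance checks against GP stay False.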

> It does not upstream the abstract InputTransform classes to GPyTorch. That'll be done if we decide to go forward with this design.

@gpleiss do you have any high-level feedback on the transform setup (https://github.com/pytorch/botorch/tree/main/botorch/models/transforms) that we'd want to incorporate when upstreaming those?

One point that @j-wilson had brought up is that if the transforms are expensive and not learnable (e.g. a pre-fit NN feature extractor), then repeatedly applying them to the same inputs during training (for the full-batch case of exact GPs, anyway) could be quite wasteful. Is there an elegant solution to this, e.g. caching the transformed values of the training data and evicting that cache when the training data is reset?
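One possible shape for such a cache, as a minimal sketch rather than anything proposed in the thread: CachedTransform and ToyModel are hypothetical names, the "expensive" transform is a plain Python function on lists of floats, and the sketch assumes that the only calls with is_training_input=True pass the model's stored training data. Test inputs are never cached, and set_train_data evicts the cache.

```python
from typing import Callable, List, Optional


class CachedTransform:
    """Hypothetical wrapper caching a non-learnable transform of the training inputs."""

    def __init__(self, fn: Callable[[List[float]], List[float]]) -> None:
        self.fn = fn
        self._train_cache: Optional[List[float]] = None
        self.num_evals = 0  # counts how often the expensive fn actually runs

    def _eval(self, X: List[float]) -> List[float]:
        self.num_evals += 1
        return self.fn(X)

    def __call__(self, X: List[float], is_training_input: bool) -> List[float]:
        if not is_training_input:
            return self._eval(X)  # test inputs change between calls; don't cache them
        if self._train_cache is None:
            self._train_cache = self._eval(X)
        return self._train_cache

    def evict(self) -> None:
        self._train_cache = None


class ToyModel:
    """Hypothetical model whose training-data setter invalidates the cache."""

    def __init__(self, train_X: List[float], transform: CachedTransform) -> None:
        self.transform = transform
        self.set_train_data(train_X)

    def set_train_data(self, train_X: List[float]) -> None:
        self.train_X = list(train_X)
        self.transform.evict()  # new training data makes the cached values stale

    def train_features(self) -> List[float]:
        return self.transform(self.train_X, is_training_input=True)
```

The weak point of this design is the assumption above: if training inputs can change without going through set_train_data, the cache silently goes stale, which is presumably why an elegant general solution is not obvious.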

Balandat, Sep 06 '22 23:09

@Balandat I really like the botorch API, and this would be super useful to have upstream in GPyTorch!

> One point that @j-wilson had brought up is that if the transforms are expensive and not learnable (e.g. a pre-fit NN feature extractor), then repeatedly applying them to the same inputs during training (for the full-batch case of exact GPs, anyway) could be quite wasteful. Is there an elegant solution to this, e.g. caching the transformed values of the training data and evicting that cache when the training data is reset?

There probably is an elegant way to do this, but nothing really comes to mind. We should circle back to this at some point, but at the very least, a power user could, e.g., apply a pre-trained NN to the inputs without using the transforms API.
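The workaround mentioned above amounts to precomputing features once, outside the model. A minimal sketch, where frozen_extractor and toy_training_loop are hypothetical stand-ins for a pre-trained NN and a model-fitting loop, and inputs are plain lists of floats:

```python
from typing import List

calls = {"n": 0}  # counts invocations of the expensive extractor


def frozen_extractor(X: List[float]) -> List[List[float]]:
    """Hypothetical pre-trained, non-learnable feature extractor (stand-in for a NN)."""
    calls["n"] += 1
    return [[x, x * x] for x in X]


def toy_training_loop(features: List[List[float]], num_iters: int) -> float:
    # Stand-in for an optimizer loop: every iteration reuses the precomputed
    # features instead of re-running the extractor.
    total = 0.0
    for _ in range(num_iters):
        total += sum(sum(f) for f in features)
    return total


train_X = [0.0, 1.0, 2.0]
train_features = frozen_extractor(train_X)  # applied once, outside any transforms API
toy_training_loop(train_features, num_iters=100)
test_features = frozen_extractor([3.0])  # must be applied manually at test time as well
```

The trade-off is that the user becomes responsible for applying the same extractor to test inputs, which is exactly the bookkeeping the transforms API would otherwise handle.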

gpleiss, Sep 09 '22 20:09