
[RFC] Move input transforms to GPyTorch

Open · saitcakmak opened this issue 2 years ago · 1 comment

This diff presents a minimal implementation of input transforms in GPyTorch, as requested in #1652. This should be viewed together with pytorch/botorch#1372. The input transforms themselves are currently implemented in https://github.com/pytorch/botorch/blob/cdd668d18b2a7e35bed09b7a2b2fca40e5fd2067/botorch/models/transforms/input.py

What this does:

  • Moves transform_inputs from the BoTorch Model to the GPyTorch GP class, with some modifications to explicitly identify whether the given inputs are train or test inputs.
  • Modifies the InputTransform.forward call to use an is_training_input argument instead of the self.training check when applying transforms that have transform_on_train=True.
  • Removes the preprocess_transform method, since it is no longer needed.
  • For ExactGP models, transforms both train and test inputs in __call__. For train_inputs it always uses is_training_input=True. For generic inputs, it uses is_training_input=self.training, which signals that the inputs are training inputs when the model is in train mode and test inputs when the model is in eval mode.
  • For ApproximateGP models, applies the transform to inputs in __call__ using is_training_input=self.training. Again, this signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transforms inducing_points, which fixes the previous bug where inducing_points got transformed in train mode but not in eval mode. The user is expected to define inducing points in the appropriate space (typically the normalized space / unit cube).
  • For BoTorch SingleTaskVariationalGP, moves the input_transform attribute down to _SingleTaskVariationalGP, which is the actual ApproximateGP instance. This makes the transform accessible from GPyTorch.
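To make the proposal above concrete, here is a minimal sketch of how an InputTransform.forward driven by an explicit is_training_input flag (rather than self.training) could look. The class and argument names mirror the description in this RFC, but the exact signatures here are illustrative assumptions, not the final API:

```python
# Hypothetical sketch of the proposed flow. InputTransform, Normalize,
# is_training_input, transform_on_train, and transform_on_eval follow the
# RFC text above, but signatures are assumptions, not the final API.
import torch


class InputTransform(torch.nn.Module):
    """Base transform that decides applicability from an explicit flag."""

    transform_on_train: bool = True
    transform_on_eval: bool = True

    def forward(self, X: torch.Tensor, is_training_input: bool = False) -> torch.Tensor:
        # Use the explicit flag, not self.training, to decide whether the
        # transform applies to these particular inputs.
        if is_training_input:
            return self.transform(X) if self.transform_on_train else X
        return self.transform(X) if self.transform_on_eval else X

    def transform(self, X: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class Normalize(InputTransform):
    """Toy min-max normalization onto the unit cube."""

    def __init__(self, bounds: torch.Tensor) -> None:
        super().__init__()
        self.register_buffer("bounds", bounds)  # shape 2 x d: [lower; upper]

    def transform(self, X: torch.Tensor) -> torch.Tensor:
        lo, hi = self.bounds[0], self.bounds[1]
        return (X - lo) / (hi - lo)


tf = Normalize(bounds=torch.tensor([[0.0, 0.0], [10.0, 4.0]]))
train_X = torch.tensor([[5.0, 2.0]])

# A model's __call__ would always flag train_inputs as training inputs,
# and flag generic inputs with is_training_input=self.training:
transformed = tf(train_X, is_training_input=True)  # -> [[0.5, 0.5]]
```

Note that inducing points would simply never be passed through tf at all, which is what fixes the train/eval inconsistency described above.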

What this doesn't do:

  • It doesn't do anything about DeterministicModels. Those will still need to handle their own transforms, which is not implemented here. If we make Model inherit from GP, we can keep the existing setup with very minimal changes.
  • It does not clean up the call sites of self.transform_inputs. That method is made a no-op, and the clean-up is left for later.
  • It does not upstream the abstract InputTransform classes to GPyTorch. That will be done if we decide to go forward with this design.
  • It does not touch PairwiseGP. PairwiseGP makes some non-standard use of input transforms, so it needs an audit to make sure things still work correctly.
  • I didn't look into ApproximateGP.fantasize. It may need changes similar to ExactGP.get_fantasy_model.
  • It does not support PyroGP and DeepGP.

saitcakmak avatar Aug 31 '22 21:08 saitcakmak

cc @wjmaddox, @gpleiss, @Balandat
