[RFC] Move input transforms to GPyTorch
This diff presents a minimal implementation of input transforms in GPyTorch, as requested in #1652. This should be viewed together with pytorch/botorch#1372. The input transforms themselves are currently implemented in https://github.com/pytorch/botorch/blob/cdd668d18b2a7e35bed09b7a2b2fca40e5fd2067/botorch/models/transforms/input.py
What this does:
- Moves `transform_inputs` from the BoTorch `Model` to the GPyTorch `GP` class, with some modifications to explicitly identify whether given inputs are train or test inputs.
- Modifies the `InputTransform.forward` call to use an `is_training_input` argument instead of a `self.training` check to apply the transforms that have `transform_on_train=True`.
- Removes the `preprocess_transform` method, since it is no longer needed.
- For `ExactGP` models, it transforms both train and test inputs in `__call__`. For `train_inputs` it always uses `is_training_input=True`. For generic `inputs`, it uses `is_training_input=self.training`, which signals that these are training inputs when the model is in `train` mode, and that these are test inputs when the model is in `eval` mode.
- For `ApproximateGP` models, it applies the transform to `inputs` in `__call__` using `is_training_input=self.training`. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transforms `inducing_points`, thus fixing the previous bug where `inducing_points` got transformed in `train` but not in `eval`. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).
- For BoTorch `SingleTaskVariationalGP`, it moves the `input_transform` attribute down to `_SingleTaskVariationalGP`, which is the actual `ApproximateGP` instance. This makes the transform accessible from GPyTorch.
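To make the `is_training_input` semantics concrete, here is a minimal, self-contained sketch of the proposed dispatch. All class and attribute names besides `transform_on_train` / `transform_on_eval` / `is_training_input` (e.g. `AddOne`, `ToyExactGP`) are illustrative stand-ins, not the actual GPyTorch/BoTorch API:

```python
class InputTransform:
    # Flags mirroring the BoTorch transform options; whether the
    # transform is applied depends on an explicit is_training_input
    # flag rather than on the module's self.training state.
    transform_on_train: bool = True
    transform_on_eval: bool = True

    def forward(self, X, is_training_input=False):
        if is_training_input:
            if self.transform_on_train:
                return self.transform(X)
        elif self.transform_on_eval:
            return self.transform(X)
        return X

    def transform(self, X):
        raise NotImplementedError


class AddOne(InputTransform):
    # Toy transform that only applies during training, to make the
    # dispatch observable.
    transform_on_train = True
    transform_on_eval = False

    def transform(self, X):
        return [x + 1.0 for x in X]


class ToyExactGP:
    # Minimal stand-in for the ExactGP.__call__ behavior described
    # above: train_inputs always use is_training_input=True, while
    # generic inputs follow the model's train/eval mode.
    def __init__(self, train_inputs, input_transform):
        self.train_inputs = train_inputs
        self.input_transform = input_transform
        self.training = True

    def __call__(self, inputs):
        t = self.input_transform
        train = t.forward(self.train_inputs, is_training_input=True)
        test = t.forward(inputs, is_training_input=self.training)
        return train, test


model = ToyExactGP(train_inputs=[0.0, 1.0], input_transform=AddOne())
model.training = False  # eval mode
train, test = model([2.0])
# train inputs are still transformed (is_training_input=True),
# while the eval-mode inputs pass through untouched.
```

The point of the explicit flag is visible in the last two lines: in `eval` mode, `train_inputs` are still transformed consistently with how the model was fit, which a plain `self.training` check cannot express.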
What this doesn't do:
- It doesn't do anything about `DeterministicModel`s. Those will still need to handle their own transforms, which is not implemented here. If we make `Model` inherit from `GP`, we can keep the existing setup with very minimal changes.
- It does not clean up the call sites of `self.transform_inputs`. It is just made into a no-op, and the clean-up is left for later.
- It does not upstream the abstract `InputTransform` classes to GPyTorch. That will be done if we decide to go forward with this design.
- It does not touch `PairwiseGP`. `PairwiseGP` has some non-standard use of input transforms, so it needs an audit to make sure things still work fine.
- I didn't look into `ApproximateGP.fantasize`. This may need some changes similar to `ExactGP.get_fantasy_model`.
- It does not support `PyroGP` and `DeepGP`.
cc @wjmaddox, @gpleiss, @Balandat