[RFC] Move input transforms to GPyTorch
Summary: This diff presents a minimal implementation of input transforms in GPyTorch. See cornellius-gp/gpytorch#2114 for the GPyTorch side of these changes.
What this does:
- Moves `transform_inputs` from the BoTorch `Model` to the GPyTorch `GP` class, with some modifications to explicitly identify whether given inputs are train or test inputs.
- Modifies the `InputTransform.forward` call to use an `is_training_input` argument instead of the `self.training` check to apply the transforms that have `transform_on_train=True` (see the sketch after this list).
- Removes the `preprocess_transform` method, since it is no longer needed.
- For `ExactGP` models, it transforms both train and test inputs in `__call__`. For `train_inputs` it always uses `is_training_input=True`. For generic `inputs`, it uses `is_training_input=self.training`, which signals that these are training inputs when the model is in `train` mode, and that these are test inputs when the model is in `eval` mode.
- For `ApproximateGP` models, it applies the transform to `inputs` in `__call__` using `is_training_input=self.training`. This again signifies whether the given inputs are train or test inputs based on the mode of the model. Note that this NEVER transforms `inducing_points`, which fixes the previous bug of `inducing_points` getting transformed in `train` but not in `eval`. It is expected that the user will define inducing points in the appropriate space (mostly the normalized space / unit cube).
- For BoTorch's `SingleTaskVariationalGP`, it moves the `input_transform` attribute down to `_SingleTaskVariationalGP`, which is the actual `ApproximateGP` instance. This makes the transform accessible from GPyTorch.
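To make the proposed dispatch concrete, here is a minimal, self-contained sketch. Only the `is_training_input` flag and the `transform_on_train`/`transform_on_eval` semantics come from this diff; the class bodies and the toy `AddOne` transform are illustrative, not the actual implementation.

```python
import torch


class InputTransform(torch.nn.Module):
    """Simplified stand-in for the abstract transform class."""

    def __init__(self, transform_on_train: bool = True, transform_on_eval: bool = True):
        super().__init__()
        self.transform_on_train = transform_on_train
        self.transform_on_eval = transform_on_eval

    def forward(self, X: torch.Tensor, is_training_input: bool = False) -> torch.Tensor:
        # The explicit flag replaces the old `self.training` check.
        apply = self.transform_on_train if is_training_input else self.transform_on_eval
        return self.transform(X) if apply else X

    def transform(self, X: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError


class AddOne(InputTransform):
    """Toy transform used only to exercise the dispatch logic."""

    def transform(self, X: torch.Tensor) -> torch.Tensor:
        return X + 1.0


t = AddOne(transform_on_train=True, transform_on_eval=False)
train_inputs = t(torch.zeros(2, 3), is_training_input=True)   # transformed
test_inputs = t(torch.zeros(2, 3), is_training_input=False)   # left as-is

# In a model's __call__ under this proposal: ExactGP passes
# is_training_input=True for train_inputs and is_training_input=self.training
# for generic inputs; ApproximateGP passes is_training_input=self.training
# for inputs and never transforms inducing_points.
```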
What this doesn't do:
- It doesn't do anything about `DeterministicModel`s. Those will still need to deal with their own transforms, which is not implemented here. If we make `Model` inherit from `GP`, we can keep the existing setup with very minimal changes.
- It does not clean up the call sites of `self.transform_inputs`. This is just made into a no-op, and the clean-up is left for later.
- It does not upstream the abstract `InputTransform` classes to GPyTorch. That'll be done if we decide to go forward with this design.
- It does not touch `PairwiseGP`. `PairwiseGP` has some non-standard use of input transforms, so it needs an audit to make sure things still work fine.
- I didn't look into `ApproximateGP.fantasize`. This may need some changes similar to `ExactGP.get_fantasy_model`.
- It does not support `PyroGP` and `DeepGP`.
Differential Revision: D39147547
I like this design!
> It doesn't do anything about `DeterministicModel`s. Those will still need to deal with their own transforms, which is not implemented here. If we make `Model` inherit from `GP`, we can keep the existing setup with very minimal changes.
I think we probably want to have something like a GPyTorch `BaseModel` that implements the transform handling in a general sense, and then have `GP` inherit from that (so we don't have deterministic models suddenly be GPs...).
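A rough sketch of that hierarchy, assuming a hypothetical `BaseModel` (the names and API here are illustrative, not existing GPyTorch classes):

```python
import torch


class BaseModel(torch.nn.Module):
    """Hypothetical base class that owns transform handling."""

    input_transform = None  # optional InputTransform instance

    def transform_inputs(self, X, is_training_input: bool = False):
        if self.input_transform is None:
            return X
        return self.input_transform(X, is_training_input=is_training_input)


class GP(BaseModel):
    """GPs inherit the transform handling from BaseModel."""


class DeterministicModel(BaseModel):
    """Deterministic models get the same handling without becoming GPs."""
```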
> It does not upstream the abstract `InputTransform` classes to GPyTorch. That'll be done if we decide to go forward with this design.
@gpleiss do you have any high-level feedback on the transform setup (https://github.com/pytorch/botorch/tree/main/botorch/models/transforms) that we'd want to incorporate when upstreaming those?
One point that @j-wilson had brought up is that if the transforms are expensive and not learnable (e.g. a pre-fit NN feature extractor) then repeatedly applying it to the same inputs during training (for the full batch case of exact GPs anyway) could be quite wasteful. Is there an elegant solution to this by means of caching the transformed values of the training data and evicting that cache when they are reset?
@Balandat I really like the botorch API, and this would be super useful to have upstream in GPyTorch!
> One point that @j-wilson had brought up is that if the transforms are expensive and not learnable (e.g. a pre-fit NN feature extractor) then repeatedly applying it to the same inputs during training (for the full batch case of exact GPs anyway) could be quite wasteful. Is there an elegant solution to this by means of caching the transformed values of the training data and evicting that cache when they are reset?
There probably is an elegant way to do this, but nothing really comes to mind. We should circle back to this at some point, but at the very least a power user could (e.g.) apply a pre-trained NN to the inputs without using the transforms API.
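For concreteness, a minimal sketch of that power-user route, assuming a made-up pre-fit feature extractor (the network, shapes, and data here are all illustrative): apply the expensive, non-learnable transform once up front, outside the transforms API, so it is not recomputed on every training iteration.

```python
import torch
from botorch.models import SingleTaskGP

# Stand-in for a pre-fit NN feature extractor (hypothetical architecture).
feature_extractor = torch.nn.Sequential(
    torch.nn.Linear(10, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2)
)
feature_extractor.requires_grad_(False)  # pre-fit, not learnable

train_X = torch.rand(20, 10)   # toy data
train_Y = torch.randn(20, 1)

with torch.no_grad():
    train_Z = feature_extractor(train_X)  # transform once, up front

model = SingleTaskGP(train_Z, train_Y)  # the model only sees features

# At query time the same extractor must be applied to test points, e.g.:
# posterior = model.posterior(feature_extractor(test_X))
```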