[Feature Request] Minibatch training when mll is an _ApproximateMarginalLogLikelihood
🚀 Feature Request
Now that fit_gpytorch_mll exists and uses multiple dispatch, it seems like it would be fairly straightforward to support minibatch training by registering a fit_gpytorch_torch_stochastic (or similar) as the optimizer for _ApproximateMarginalLogLikelihood mlls.
Motivation
Is your feature request related to a problem? Please describe.
As far as I can tell from browsing the code, running fit_gpytorch_mll on an ApproximateGPyTorchModel would just use full-batch training. As a result, we have typically still been brewing our own GPyTorch models + training code (e.g., for latent space optimization tasks), despite the existence of ApproximateGPyTorchModel. We're planning on submitting a PR with a latent space bayesopt tutorial, but I'd like it to be more BoTorch-y than it currently is -- right now the actual model handling is entirely outside of BoTorch for this reason.
Pitch
Describe the solution you'd like
- [x] Write fit_gpytorch_torch_stochastic in botorch.optim.fit that does minibatch training with a user-specified batch size. For now, I was thinking this can just make a standard DataLoader over the train targets and inputs -- handling the case where train_targets is actually a tuple might require more thought if we wanted to support that out of the gate. maxiter in the stopping criterion would refer to a number of epochs of training. (A rough sketch of such a loop follows this list.)
- [x] Register fit_gpytorch_torch_stochastic as the default optimizer, via a _fit_approximategp_stochastic in botorch.fit registered to the dispatcher for (_ApproximateMarginalLogLikelihood, Likelihood, ApproximateGPyTorchModel).
- [ ] (Possibly breaking) As described above, this would leave it to the user to decide to do full-batch optimization, either by specifying fit_gpytorch_torch manually as the optimizer or (equivalently, with negligible overhead) by specifying the batch size to be the full N. One solution might be to just call the fallback fit if a minibatch size / optimizer isn't specified by the user? On the other hand, in the long run it probably makes sense to assume by default that the user wants to do some kind of stochastic optimization if they're going to the trouble of using a variational approximate GP specifically, rather than just e.g. an inducing point kernel on an ExactGP.
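To make the first item concrete, here is a rough, hypothetical sketch of what such a minibatch fitter could look like (the function name, signature, and defaults are placeholders, not BoTorch's actual API, and it assumes mll.model exposes train_inputs / train_targets as tensors):

import torch
from torch.utils.data import DataLoader, TensorDataset

def fit_gpytorch_torch_stochastic(mll, batch_size=64, maxiter=100, lr=0.01):
    # Hypothetical sketch: minibatch training of an approximate-GP ELBO.
    # Assumes the mll was constructed with num_data=N so minibatch scaling is correct.
    mll.train()
    model = mll.model
    dataset = TensorDataset(*model.train_inputs, model.train_targets)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(mll.parameters(), lr=lr)
    for _epoch in range(maxiter):  # maxiter counts epochs, per the proposal above
        for *inputs, targets in loader:
            optimizer.zero_grad()
            loss = -mll(model(*inputs), targets)
            loss.backward()
            optimizer.step()
    return mll.eval()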
Are you willing to open a pull request? (See CONTRIBUTING) Yeah
This is a great suggestion, and IIRC @j-wilson has played around with this a bit before. Not sure what state that is in and whether it makes sense for him to push out a draft of this, or whether it's better to just start fresh with a PR on your end (seems reasonably straightforward all in all). @j-wilson any thoughts here?
One solution might be to just call the fallback fit if a minibatch size / optimizer isn't specified by the user? On the other hand, in the long run, it probably makes sense to assume by default that the user wants to do some kind of stochastic optimization if they're going to the trouble of using a variational approximate GP specifically rather than just e.g. an inducing point kernel on an ExactGP.
Yeah that makes sense to me. I think if the user is using variational approximate GP models we can assume that they'd be able to manually specify the full batch training optimizer if needed. Another option would be to parse this somehow from the kwargs, but I don't think we need to worry about this for now.
Here's a rough draft of what this might look like: https://github.com/pytorch/botorch/compare/main...jacobrgardner:botorch:stochastic_fitting
The high-level fitting works great (both on a piece of code I've been testing and on our homebrew model fitting). Still a few TODOs even before code review:
- [ ] I made a quick GPyTorchDataset to let us handle train_inputs as a tuple. Probably shouldn't live inline, but wasn't sure where you'd want it.
- [ ] Still need to decide how to allow for the full-batch special case. Can either make optimizer an arg, or make minibatch_size default to None and do full batch if it's not specified at all.
- [ ] Unit tests, doc updates
@j-wilson @Balandat just let me know if you all don't already have something further along than this, and I can open it as a PR to track the TODOs there.
(Edit: Oops, some automated thing must have run black on the files before committing, sorry about the irrelevant parts of the linked diff.)
IIRC, @mshvartsman et al. use ApproximateGP with full-batch training. cc'ing in case they have any input on this.
@jacobrgardner Hi Jake. Fully on board with you here. As Max mentioned, I've put together a draft for this as well. At a glance, it looks pretty similar to your implementation.
The main difference seems to be that I actually just rewrote fit_gpytorch_torch to be more generic instead of introducing a separate method. This isn't necessarily a better way of doing things; I just don't like the current fit_gpytorch_torch method...
Aside from that, I have data_loader as an optional argument, with the method defaulting to full-batch. Under this approach, the responsibility of constructing data_loader is off-loaded to the fit_gpytorch_mll subroutine. This same subroutine would also be responsible for throwing an MDNotImplementedError in cases where the amount of training data is sufficiently small for scipy.optimize to be the preferred optimizer.
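For illustration, the interface could be shaped roughly like this hypothetical sketch (not BoTorch's current fit_gpytorch_torch, and not necessarily the actual draft; full batch is then just a DataLoader whose single batch is the entire training set):

from typing import Optional
from torch.utils.data import DataLoader, TensorDataset

def fit_gpytorch_torch(mll, data_loader: Optional[DataLoader] = None, **kwargs):
    # Hypothetical signature: default to full-batch when no DataLoader is given.
    if data_loader is None:
        dataset = TensorDataset(*mll.model.train_inputs, mll.model.train_targets)
        data_loader = DataLoader(dataset, batch_size=len(dataset))
    # ... then run the same epoch/minibatch loop as in the earlier sketch,
    # iterating over data_loader ...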
Would something like this work for your use cases?
Regarding GPyTorchDataset, I'm not sure I understand the need for this class. How about:
from torch.utils.data import DataLoader, TensorDataset

# Wrap the (possibly multiple) train inputs and the targets in a single dataset.
dataset = TensorDataset(*model.train_inputs, model.train_targets)
data_loader = DataLoader(dataset, **kwargs)
for batch_idx, (*inputs, targets) in enumerate(data_loader):
    # do stuff
    ...
If we end up with cases where train_targets is also Tuple[Tensor, ...], we'd need to update the (*inputs, targets) bit, but this seems doable?
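For example, if targets did become a tuple, one could unpack by position instead of relying on the starred assignment (illustrative only):

from torch.utils.data import DataLoader, TensorDataset

# Illustrative: works whether train_targets is a single tensor or a tuple of tensors.
num_inputs = len(model.train_inputs)
targets = model.train_targets if isinstance(model.train_targets, tuple) else (model.train_targets,)
dataset = TensorDataset(*model.train_inputs, *targets)
for batch in DataLoader(dataset, **kwargs):
    inputs, targets = batch[:num_inputs], batch[num_inputs:]
    # do stuff
    ...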
@j-wilson Ah, yeah looks like we can just use TensorDataset there.
In terms of the rest, how do you envision the user specifying that minibatch training should be used? Would the idea be to do something like fit_gpytorch_mll(mll, my_data_loader), overriding the use of model.train_inputs entirely? Or would I specify the minibatch size, and fit_gpytorch_mll would do some typechecking to make sure I'm using Approximate* and then make a data loader?
I guess I'm personally fine with essentially any of the proposed interfaces here.
@jacobrgardner Good questions.
I hadn't actually considered a solution like fit_gpytorch_mll(mll, data_loader). I really like this API, but fear it may be too heavy-duty for simple use cases.
My thought had been to make a create_data_loader(model, **kwargs) -> DataLoader helper that abstracts away DataLoader construction. We would then add data_loader: Union[DataLoader, Dict[str, Any]] as a keyword to the MD subroutine, which would internally call create_data_loader(mll.model, **data_loader) when data_loader is passed as a dict.
A typical call pattern might then look something like:
fit_gpytorch_mll(mll, data_loader={"batch_size": 128})
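A minimal sketch of that helper and the dict-expansion step, assuming the model exposes train_inputs / train_targets (both functions here are hypothetical):

from typing import Any, Dict, Union
from torch.utils.data import DataLoader, TensorDataset

def create_data_loader(model, **kwargs) -> DataLoader:
    # Hypothetical helper: build a DataLoader from the model's stored training data.
    dataset = TensorDataset(*model.train_inputs, model.train_targets)
    return DataLoader(dataset, **kwargs)

def _resolve_data_loader(mll, data_loader: Union[DataLoader, Dict[str, Any]]) -> DataLoader:
    # Hypothetical: expand a dict like {"batch_size": 128} into an actual DataLoader.
    if isinstance(data_loader, dict):
        return create_data_loader(mll.model, **data_loader)
    return data_loader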
@j-wilson Okay, so if you all think a rewrite of fit_gpytorch_torch is warranted, maybe the right solution here is something in the middle, where we add a top-level _fit_approximategp_stochastic, because then we can use the dispatcher to typecheck that the mll is an ApproximateMLL and the model is an approximate model.
Then, both _fit_approximategp_stochastic and _fit_fallback end up calling fit_gpytorch_torch (or scipy for the latter), but _fit_approximategp_stochastic enables batch size < N functionality, while _fit_fallback throws a warning if the user specifies a batch_size < N, with the warning saying that the types didn't match well enough for the dispatcher, so we're doing full batch?
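In outline, that split could look something like the following hypothetical sketch (the commented-out decorators stand in for the dispatcher registration, whose exact form in BoTorch may differ, and the fitters are the ones sketched earlier in this thread):

import warnings

# @dispatcher.register(_ApproximateMarginalLogLikelihood, Likelihood, ApproximateGPyTorchModel)
def _fit_approximategp_stochastic(mll, _likelihood, _model, *, batch_size=64, **kwargs):
    # Types matched: minibatch training is appropriate.
    return fit_gpytorch_torch_stochastic(mll, batch_size=batch_size, **kwargs)

# @dispatcher.register(MarginalLogLikelihood, Likelihood, GPyTorchModel)
def _fit_fallback(mll, _likelihood, _model, *, batch_size=None, **kwargs):
    if batch_size is not None:
        warnings.warn(
            "batch_size was specified, but the (mll, likelihood, model) types did not "
            "dispatch to the stochastic fitter; falling back to full-batch fitting."
        )
    return fit_gpytorch_torch(mll, data_loader=None, **kwargs)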
@jacobrgardner Up for discussion. A naive implementation would probably see data_loader as a _fit_approximategp_stochastic-specific keyword argument that gets ignored by other fit_gpytorch_mll subroutines.
it probably makes sense to assume by default that the user wants to do some kind of stochastic optimization if they're going to the trouble of using a variational approximate GP specifically
Not sure if that's a safe assumption :). As @saitcakmak said, in AEPsych we pretty much exclusively do full-batch, small-data fit_gpytorch_model with ApproximateGP (mainly here: https://github.com/facebookresearch/aepsych/blob/main/aepsych/models/base.py#L423). Most of what we use VI for is non-Gaussian likelihoods (Bernoulli, categorical, etc.), not big data. It's not a huge deal for us to change our calls, but in my experience the scipy optimizer is dramatically faster when it works, and I wouldn't want the default for new users to be some flavor of SGD on small data. I don't think that when I started using gpytorch/botorch I would've known to switch optimizers for our setting.
So my vote would be to either [a] retain the full-batch-with-SAA default and warn if the data is too large, or [b] have a sensible, user-adjustable cutoff to switch between the fitting strategies (similarly to how gpytorch switches between Cholesky and CG for solves and logdets). I think I'd prefer [b] over [a]; we'd just need to tune the cutoff.
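One way option [b] could look in practice (purely illustrative; the name and default value here are made up, and the cutoff would need tuning):

# Illustrative sketch of a user-adjustable cutoff between fitting strategies,
# loosely analogous to gpytorch switching between Cholesky and CG by problem size.
MAX_FULL_BATCH_DATA = 2048  # hypothetical default

def select_fitting_strategy(num_data: int, cutoff: int = MAX_FULL_BATCH_DATA) -> str:
    # Small data: full-batch scipy/L-BFGS; large data: minibatch SGD-style fitting.
    return "full_batch" if num_data <= cutoff else "minibatch"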
Hi folks. I've put together a PR (#1439) that implements the above. This ended up being a larger change than I had originally anticipated, but hopefully people will agree that these are the "right" changes (or at least trending in that direction).
The best course of action in terms of balancing the specific functionality requested here with the overall design seemed to be to introduce a loss closure abstraction. This allows us to abstract away things like DataLoaders, while also enabling the user to specify custom routines for evaluating their loss functions.
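To give a flavor of the idea (this is an illustration of the general loss-closure pattern under assumed names, not the actual API introduced in #1439):

import torch
from torch.utils.data import DataLoader

def make_minibatch_loss_closure(mll, data_loader: DataLoader):
    # The optimizer only ever sees a zero-argument callable returning the loss,
    # so how the data is sourced (full batch, a DataLoader, something custom)
    # stays hidden behind the closure.
    batches = iter(data_loader)

    def closure() -> torch.Tensor:
        nonlocal batches
        try:
            *inputs, targets = next(batches)
        except StopIteration:
            batches = iter(data_loader)  # start a new epoch
            *inputs, targets = next(batches)
        return -mll(mll.model(*inputs), targets)

    return closure

# A generic fitting loop then just repeats: loss = closure(); loss.backward(); optimizer.step()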
I haven't tested this yet, but I'm hopeful that we'll be able to use e.g. torch.jit.script to compile these closures and expedite training.
I'm hopeful that we'll be able to use e.g. torch.jit.script to compile these closures and expedite training.
Sam has been having some good success using torchdynamo/torchinductor; it would be interesting to see what this does here.