[Feature Request] gpytorch on mps
# 🚀 Feature Request
Enabling gpytorch to run on Apple Silicon mps devices
## Motivation
Apple Silicon has been on the scene for a while now and looks like it is here to stay, and many devs use it as their primary machine. Local development on such devices would benefit from GPyTorch being able to run on mps devices and not only cuda.
## Pitch
We could replace hardcoded calls to cuda methods in PyTorch with their equivalents from the mps module: https://pytorch.org/docs/stable/notes/mps.html
I'm not sure about the extent of rewriting involved, if any is required, though.
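As a rough sketch of what the device-agnostic version could look like (the `pick_device` helper is just illustrative; `torch.cuda.is_available()` and `torch.backends.mps.is_available()` are the documented PyTorch checks):

```python
import torch

def pick_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = pick_device()
train_x = torch.linspace(0, 1, 100, device=device)
```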
## Additional context
That would be great. Cc @sdaulton
I don't have experience with this particular codebase, but I'd be happy to contribute if this is something that is of interest to the community.
It would definitely be of interest!
+1 on this - it'll be very helpful!
I am interested in contributing to this!
Instead of using `train_x = train_x.cuda()`, the following seems to work for me:

```python
gpu = torch.device("mps:0")
train_x = train_x.to(gpu, dtype=torch.float32)
```
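To apply the same change end to end, here is a minimal sketch of moving both the data and the model to the device (my adaptation of the standard ExactGP setup from the simple regression tutorial, not verbatim from the notebook):

```python
import math
import torch
import gpytorch

device = torch.device("mps:0")

# MPS does not support float64, hence the explicit float32 dtype.
train_x = torch.linspace(0, 1, 100, device=device, dtype=torch.float32)
train_y = torch.sin(train_x * (2 * math.pi))

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# Move the likelihood and model parameters to the same device as the data.
likelihood = gpytorch.likelihoods.GaussianLikelihood().to(device)
model = ExactGPModel(train_x, train_y, likelihood).to(device)
```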
However, when I try this on `Simple_GP_Regression_CUDA.ipynb`, I get the following error:

```
File ~/Developer/gpytorch/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
     62 # Get the log prob of the marginal distribution
     63 output = self.likelihood(function_dist, *params)
---> 64 res = output.log_prob(target)
     65 res = self._add_other_terms(res, params)
     67 # Scale by the amount of data we have
...
---> 20 L, info = torch.linalg.cholesky_ex(A, out=out)
     21 if settings.trace_mode.on() or not torch.any(info):
     22     return L

NotImplementedError: The operator 'aten::linalg_cholesky_ex.L' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```
So it seems like the Cholesky operator does not have an MPS backend implementation in PyTorch yet?
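In the meantime, the CPU-fallback workaround mentioned in the error message can be tried like this (a sketch; as far as I can tell, the environment variable has to be set before torch is first imported, or it has no effect):

```python
import os

# Enable CPU fallback for ops without an MPS implementation.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402

A = torch.eye(3, device="mps")
# With the fallback on, unsupported ops like cholesky_ex run on the CPU
# and the result is copied back, which is slower than native MPS.
L, info = torch.linalg.cholesky_ex(A)
print(L, info)
```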