gpytorch

[Feature Request] gpytorch on mps

Open florenzi002 opened this issue 3 years ago • 5 comments

🚀 Feature Request

Enable gpytorch to run on Apple Silicon MPS devices.

Motivation

Apple Silicon has been around for a while now and looks like it is here to stay, and many developers use it as their primary machine. Local development on such machines would benefit from gpytorch being able to run on MPS devices, not only on CUDA.

Pitch

We could replace hardcoded calls to CUDA methods in PyTorch with their equivalents for the MPS backend: https://pytorch.org/docs/stable/notes/mps.html
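A minimal sketch of what replacing hardcoded .cuda() calls could look like: choose the device once, then move tensors with .to(device). The helper name get_default_device is hypothetical and not part of gpytorch; torch.backends.mps is only present from PyTorch 1.12 onward, hence the getattr guard.

```python
import torch


def get_default_device() -> torch.device:
    """Hypothetical helper: prefer CUDA, then Apple's MPS backend, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists in PyTorch 1.12 and later
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = get_default_device()
# Device-agnostic replacement for x = x.cuda()
x = torch.randn(10, 2).to(device)
```

With this pattern the same code path runs on CUDA, MPS, or CPU, and only the device selection knows about individual backends.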

I am not sure how much rewriting, if any, would be involved, though.

Additional context

florenzi002, Nov 30 '22

That would be great. Cc @sdaulton

Balandat, Nov 30 '22

I don't have experience with this particular codebase, but I'd be happy to contribute if this is of interest to the community.

florenzi002, Dec 01 '22

It would definitely be of interest!

Balandat, Dec 01 '22

+1 on this - it'll be very helpful!

yw5aj, Feb 05 '23

I am interested in contributing to this!

Instead of using train_x = train_x.cuda(), the following seems to work for me:

    gpu = torch.device("mps:0")
    train_x = train_x.to(gpu, dtype=torch.float32)
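The dtype=torch.float32 in the snippet above matters because the MPS backend does not support float64 tensors. A minimal sketch that generalizes it, assuming a hypothetical helper named to_device (not a gpytorch API) that casts floating-point tensors to float32 only when the target is MPS and leaves dtypes alone otherwise:

```python
import math

import torch


def to_device(device: torch.device, *tensors: torch.Tensor) -> tuple:
    """Hypothetical helper: move tensors to `device`, casting floating-point
    tensors to float32 on MPS, which has no float64 support."""
    moved = []
    for t in tensors:
        if device.type == "mps" and t.is_floating_point():
            moved.append(t.to(device, dtype=torch.float32))
        else:
            moved.append(t.to(device))
    return tuple(moved)


train_x = torch.linspace(0, 1, 100, dtype=torch.float64)
train_y = torch.sin(train_x * (2 * math.pi))
# On CPU the float64 dtype is preserved; on MPS both would become float32
train_x, train_y = to_device(torch.device("cpu"), train_x, train_y)
```

The model and likelihood are torch.nn.Module instances, so they would presumably need the same treatment, e.g. model.to(device, dtype=torch.float32).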

However, when I try this on Simple_GP_Regression_CUDA.ipynb, I get the following error:

    File ~/Developer/gpytorch/gpytorch/mlls/exact_marginal_log_likelihood.py:64, in ExactMarginalLogLikelihood.forward(self, function_dist, target, *params)
         62 # Get the log prob of the marginal distribution
         63 output = self.likelihood(function_dist, *params)
    ---> 64 res = output.log_prob(target)
         65 res = self._add_other_terms(res, params)
         67 # Scale by the amount of data we have
    ...
    ---> 20 L, info = torch.linalg.cholesky_ex(A, out=out)
         21 if settings.trace_mode.on() or not torch.any(info):
         22     return L

NotImplementedError: The operator 'aten::linalg_cholesky_ex.L' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
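The temporary fix suggested in the error message can be applied from Python, as in this minimal sketch. The variable appears to be read when torch is loaded, so it should be set before the first import torch (in a notebook, restart the kernel if torch is already imported); exporting it in the shell that launches Jupyter is an alternative. As the warning says, any op that falls back runs on the CPU and will be slower.

```python
import os

# Ask PyTorch to run ops that have no MPS kernel on the CPU instead of raising.
# Set this before importing torch so the MPS backend sees it at load time.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (import deliberately placed after the env var)
```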

So it seems like the Cholesky operator does not have an MPS backend implementation in PyTorch yet?
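A narrower workaround, sketched here under the assumption that only the Cholesky factorization is missing, is to factor the matrix on the CPU and move the factor back to the original device. The helper name cholesky_via_cpu is hypothetical; wiring it into gpytorch's own Cholesky wrapper (the frame shown at line 20 of the traceback) is untested, and every call pays for a round trip between devices.

```python
import torch


def cholesky_via_cpu(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical workaround: factor A on the CPU, then return the lower
    triangular factor on A's original device (e.g. MPS)."""
    L, info = torch.linalg.cholesky_ex(A.cpu())
    # info is nonzero where the input was not positive definite
    if torch.any(info):
        raise RuntimeError("matrix is not positive definite")
    return L.to(A.device)


# Small symmetric positive definite test matrix: B @ B.T + n * I
B = torch.randn(4, 4)
A = B @ B.T + 4 * torch.eye(4)
L = cholesky_via_cpu(A)
```

This is roughly what PYTORCH_ENABLE_MPS_FALLBACK=1 does for every unsupported op; a targeted version limits the CPU round trips to the one op that needs them.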

raghuramshankar, Apr 11 '23