SGPR's speed is abnormally slow when botorch is imported

Open · pipme opened this issue on Jul 3, 2023 · 6 comments

🐛 Bug

For context, I am working through this SGPR tutorial notebook https://docs.gpytorch.ai/en/latest/examples/02_Scalable_Exact_GPs/SGPR_Regression_CUDA.html and would like to optimize the model with L-BFGS. Therefore, from botorch import fit_gpytorch_mll is added.

BoTorch modifies some of GPyTorch's default settings so that computations fall back to the usual exact GP path with Cholesky decomposition. However, with BoTorch imported, SGPR becomes incredibly slow even with, say, 30 inducing points.

This bug may be related to #1709.

To reproduce

from math import floor

import gpytorch
import torch
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import InducingPointKernel, RBFKernel, ScaleKernel
from gpytorch.means import ConstantMean
from scipy.io import loadmat


data = torch.Tensor(loadmat("../data/elevators.mat")["data"])
X = data[:, :-1]
X = X - X.min(0)[0]
X = 2 * (X / X.max(0)[0]) - 1
y = data[:, -1]
N_inducing = 30

train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()


class GPRegressionModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPRegressionModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = ConstantMean()
        self.base_covar_module = ScaleKernel(RBFKernel())
        self.covar_module = InducingPointKernel(self.base_covar_module, inducing_points=train_x[:N_inducing, :].clone(), likelihood=likelihood)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPRegressionModel(train_x, train_y, likelihood)

if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()

# Find optimal model hyperparameters
model.train()
likelihood.train()

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

###
# Importing botorch changes gpytorch's default settings -- the following is now very slow!
from botorch import fit_gpytorch_mll
output = model(train_x)
loss = -mll(output, train_y)
loss.backward()

# Even slower!
fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"disp": True}})

Expected Behavior

Fast runtime with a small number of inducing points when doing exact computations without approximations.

System information

GPyTorch Version: 1.11
PyTorch Version: 2.0.0
BoTorch Version: 0.8.5
OS: macOS Ventura 13.4.1 (22F82), Apple M1

pipme avatar Jul 03 '23 18:07 pipme

Thanks for flagging this

@saitcakmak, @dme65 re: side effects of changing gpytorch settings upon import in botorch - we had talked about this before; can you please look into finding a way to do this without changing non-botorch behavior?

Balandat avatar Jul 03 '23 19:07 Balandat

Hi @Balandat, thanks for the reply. Maybe I was unclear in the description. I actually want to fit SGPR with traditional exact computations (Cholesky decomposition), so the modifications made by botorch are totally reasonable as far as I can see and are exactly what I want, i.e., turning off the fast approximations. But I don't know why these modifications make SGPR so slow; this shouldn't happen with a moderately large number of inducing points (e.g. 100).

Thanks!

pipme avatar Jul 04 '23 04:07 pipme

GPyTorch uses approximate computations in multiple places by default. In this case, the difference is due to the approximate log-probability computation of the 13279-dim MVN. If I enable fast (i.e., approximate) log-prob computations, it is quite fast again:

import linear_operator.settings as linop_settings
linop_settings._fast_log_prob._default = True
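
As a side note, the same thing can be scoped locally with GPyTorch's settings context manager instead of flipping the global default; a minimal sketch, assuming the model, mll, and data from the reproduction code above:

import gpytorch

# Enable the fast (approximate) computation paths only inside this block, so the
# global defaults set by botorch stay untouched elsewhere:
with gpytorch.settings.fast_computations(log_prob=True):
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()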

@Balandat I don't think there's a clean way of doing this for all BoTorch models and not affecting any non-BoTorch model at the same time. The settings are global, so enforcing them locally is a messy business. We could offer a simple helper for reverting the changes we make by default. Otherwise, I am quite happy with the benefits of the current setup.
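
For illustration, a minimal sketch of what such a reverting helper could look like; this is hypothetical (not an existing botorch function) and only covers the single _fast_log_prob flag shown above, while a real helper would likely need to revert other settings as well:

import linear_operator.settings as linop_settings

def restore_fast_log_prob(enable: bool = True) -> None:
    # Hypothetical helper: flip the global default for approximate
    # log-prob computations back on (or off).
    linop_settings._fast_log_prob._default = enable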

saitcakmak avatar Jul 04 '23 06:07 saitcakmak

the difference is due to the approximate log-probability computation of the 13279-dim MVN

Thanks for figuring it out. In that case, doesn't it mean that GPyTorch doesn't implement SGPR properly? In SGPR one only needs to do the "expensive" computation on an $M \times M$ covariance matrix, where $M$ is the number of inducing points; i.e., the 13279-dim MVN involved in the ELBO does not require operating on a $13279 \times 13279$ matrix (see the derivations here for example). That's the whole point of using SGPR instead of an exact GP 😄.
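
For reference, a sketch of the standard identities behind this claim (my notation, not GPyTorch's): with $N$ training points, Nyström covariance $Q_{NN} = K_{NM} K_{MM}^{-1} K_{MN}$, and noise variance $\sigma^2$, the Woodbury identity and the matrix determinant lemma give

$$(Q_{NN} + \sigma^2 I)^{-1} = \sigma^{-2} I - \sigma^{-4} K_{NM}\left(K_{MM} + \sigma^{-2} K_{MN} K_{NM}\right)^{-1} K_{MN},$$

$$\log\det\left(Q_{NN} + \sigma^2 I\right) = \log\det\left(K_{MM} + \sigma^{-2} K_{MN} K_{NM}\right) - \log\det K_{MM} + N \log \sigma^2,$$

so only $M \times M$ matrices ever need to be factorized and the overall cost is $O(NM^2)$.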

Btw, mll = gpytorch.mlls.ExactMarginalLogLikelihood might not be an appropriate name/loss for SGPR; it should be something like ExactVariationalELBO that computes SGPR's ELBO, which does not seem to be implemented in GPyTorch.

pipme avatar Jul 04 '23 07:07 pipme

Not exactly. We're talking about different covariance matrices here. The 13279-dim one is the output of model(train_x), where train_x itself is 13279 x d. It is computed as the product of (13279 x M), (M x M), and (M x 13279) matrices.

To leverage SGPR, you may want to train the model using mini-batches rather than using L-BFGS on the full training data. This will make it scale to an arbitrary amount of data. From what I've seen, the MLL computation with approximate log probability is also quite fast, so using that with L-BFGS is always an option.
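
A minimal sketch of that second option, assuming the model, likelihood, mll, and training tensors from the reproduction code above:

import gpytorch
import torch

# Full-batch L-BFGS, with GPyTorch's fast (approximate) log-prob path enabled
# only inside the closure.
optimizer = torch.optim.LBFGS(model.parameters(), max_iter=50, line_search_fn="strong_wolfe")

def closure():
    optimizer.zero_grad()
    with gpytorch.settings.fast_computations(log_prob=True):
        loss = -mll(model(train_x), train_y)
        loss.backward()
    return loss

optimizer.step(closure)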

I am not that familiar with the ELBO, so I can't offer advice there.

saitcakmak avatar Jul 04 '23 19:07 saitcakmak

Just to clarify, no approximate math (i.e., iterative methods) is used under GPyTorch's default settings when running SGPR. With an InducingPointKernel, you ultimately end up with a LowRankRootAddedDiagLinearOperator, whose default inv_quad_logdet uses Cholesky and Woodbury for both the solve and the log determinant in O(nm^2) time: https://github.com/cornellius-gp/linear_operator/blob/7c5aabd8291146790e6f178ed6735f86163aca55/linear_operator/operators/low_rank_root_added_diag_linear_operator.py#L78
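
A quick way to see which operator the MLL actually works with; a sketch, assuming the reproduction code above has been run:

# Build the training-data marginal and inspect its lazy covariance. Per the
# comment above, this is expected to be a LowRankRootAddedDiagLinearOperator.
model.train()
likelihood.train()
marginal = likelihood(model(train_x))
print(type(marginal.lazy_covariance_matrix).__name__)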

BoTorch probably isn't intended to fully turn off the usage of, e.g., LowRankRootAddedDiagLinearOperator. It sounds like the real problem here is that, once again, the settings in GPyTorch controlling this are too broadly named. What we should probably do is change things so that it is possible to turn off the use of iterative solvers, but never to turn off obvious, efficient direct linear algebra where it is available.

Come to think of it, the settings controlling approximate math shouldn't really affect the running of SGPR at all, since inv_quad_logdet is directly overridden and the corresponding Functions are never called -- nothing in the code path for computing the MLL or its derivative should check those settings at all for SGPR. Is it possible that something broke when linear_operator was pulled out?

jacobrgardner avatar Jul 04 '23 20:07 jacobrgardner