
Re-think conditions for switching from Cholesky to CG

Open gpleiss opened this issue 1 year ago • 3 comments

See #2078.

Cholesky seems to be outperforming CG now, even when N=4000. It seems like dimensionality is an important factor (low input dimensionality tends to produce worse condition numbers).
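As a rough, self-contained illustration of that conditioning claim (not part of the benchmark below; the point counts, dimensions, and jitter value here are arbitrary choices), one can compare the condition number of an RBF kernel matrix on random points across input dimensions:

import torch
import gpytorch

torch.manual_seed(0)
for d in (1, 2, 10):
    x = torch.randn(2000, d)
    # Dense RBF kernel matrix plus a small jitter for numerical stability
    K = gpytorch.kernels.RBFKernel()(x).to_dense() + 1e-6 * torch.eye(2000)
    eigs = torch.linalg.eigvalsh(K)
    print(f"d={d}: condition number ~ {(eigs.max() / eigs.min()).item():.3e}")

(On GPyTorch versions predating the linear_operator backend, .evaluate() replaces .to_dense().)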

import gpytorch
import torch
import time

# Test data
torch.manual_seed(0)
train_x = torch.randn(8000, 2)
train_y = torch.randn(8000)
test_x = torch.randn(5100, 2)

# Construct model
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ZeroMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

model = ExactGPModel(train_x, train_y, gpytorch.likelihoods.GaussianLikelihood())

if torch.cuda.is_available():
    model.cuda()
    train_x = train_x.cuda()
    train_y = train_y.cuda()
    test_x = test_x.cuda()

# Train model
model.train()
model.likelihood.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.ExactMarginalLogLikelihood(model.likelihood, model)
for i in range(50):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    optimizer.step()
model.eval()

with gpytorch.settings.verbose_linalg():
    # Switch to train mode and back: entering train mode drops the cached
    # prediction terms, so the timing below starts from a cold cache
    model.train()
    model.eval()
    start_time = time.time()
    # Timing 1: exact solves via Cholesky (solves=False disables the CG path)
    with gpytorch.settings.fast_computations(solves=False):
        preds = model.likelihood(model(test_x))
    print(time.time() - start_time)

    # Reset the caches again for an apples-to-apples comparison
    model.train()
    model.eval()
    start_time = time.time()
    # Timing 2: iterative path, i.e. CG solves with LOVE predictive variances
    with gpytorch.settings.fast_pred_var():
        preds = model.likelihood(model(test_x))
    print(time.time() - start_time)
LinAlg (Verbose) - DEBUG - Running Cholesky on a matrix of size torch.Size([8000, 8000]).
covar
LinAlg (Verbose) - DEBUG - Running Cholesky on a matrix of size torch.Size([8000, 8000]).
done
0.24354290962219238
LinAlg (Verbose) - DEBUG - Running Pivoted Cholesky on a torch.Size([8000, 8000]) RHS for 15 iterations.
LinAlg (Verbose) - DEBUG - Running CG on a torch.Size([8000, 1]) RHS for 1000 iterations (tol=0.01). Output: torch.Size([8000, 1]).
covar
LinAlg (Verbose) - DEBUG - Running Lanczos on a torch.Size([8000, 8000]) matrix with a torch.Size([8000, 1]) RHS for 100 iterations.
LinAlg (Verbose) - DEBUG - Running symeig on a matrix of size torch.Size([42, 42]).
0.31032729148864746

(Also note that we seem to be running Cholesky twice: once for the predictive mean, once for the solves. However, this issue should disappear with the prediction-strategy refactor that @JonathanWenger is working on.)
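For re-thinking the switching threshold itself, a standalone micro-benchmark along these lines might help map out where the crossover actually sits. This is a sketch only: the textbook CG routine below stands in for GPyTorch's preconditioned CG, and the matrix sizes, conditioning, and tolerance are arbitrary.

import time
import torch

def cg_solve(A, b, tol=1e-2, max_iter=1000):
    # Textbook conjugate gradients on a dense SPD matrix (illustration only)
    x = torch.zeros_like(b)
    r = b - A @ x
    p = r.clone()
    rs = r.dot(r)
    for _ in range(max_iter):
        Ap = A @ p
        alpha = rs / p.dot(Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r.dot(r)
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return x

torch.manual_seed(0)
for n in (1000, 2000, 4000, 8000):
    L = torch.randn(n, n)
    A = L @ L.T / n + torch.eye(n)  # reasonably well-conditioned SPD matrix
    b = torch.randn(n)

    t0 = time.time()
    torch.cholesky_solve(b.unsqueeze(-1), torch.linalg.cholesky(A))
    t_chol = time.time() - t0

    t0 = time.time()
    cg_solve(A, b)
    t_cg = time.time() - t0
    print(f"n={n}: Cholesky {t_chol:.3f}s, CG {t_cg:.3f}s")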

gpleiss avatar May 26 '23 13:05 gpleiss

After training (the for i in range(50): loop), the model is switched to model.eval(). Then, inside the with ...verbose_linalg(): context, it is switched back to model.train() only to be switched to model.eval() again. Is there a reason for this? I tried it without the back-and-forth switching, and it still seems to work:

...
with gpytorch.settings.verbose_linalg():
    #model.train()     ##### commented this line
    #model.eval()      ##### commented this line
    start_time = time.time()
    with gpytorch.settings.fast_computations(solves=False):
        preds = model.likelihood(model(test_x))
    print(time.time() - start_time)

    model.train()
    model.eval()
    start_time = time.time()
    with gpytorch.settings.fast_pred_var():
        preds = model.likelihood(model(test_x))
    print(time.time() - start_time)

[screenshot of the resulting output]

Edit: I tested it using GPyTorch v1.11

OliverAh avatar Apr 26 '24 15:04 OliverAh

it is switched back to model.train() just to be switched back again to model.eval(). Is there a reason to do this?

Did you check the timing when you don't do that? I assume this is done to clear the caches, to get an apples-to-apples comparison.
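For reference, a minimal sketch of that idiom (reusing model and test_x from the script above, and assuming that entering train mode discards an ExactGP's cached prediction strategy, so the next eval-mode call rebuilds the caches inside the timed region):

model.train()   # drops the cached prediction strategy
model.eval()    # back to posterior-prediction mode
start_time = time.time()
with gpytorch.settings.fast_pred_var():
    preds = model.likelihood(model(test_x))  # caches are rebuilt inside the timer
print(time.time() - start_time)  # cold-cache timing, comparable across runs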

Balandat avatar Apr 27 '24 01:04 Balandat

I did, using 3 runs each. Without the back-and-forth switching it is approx. 5% faster when not specifying the contexts (with gpytorch.settings...). When specifying the contexts, the runtime varied both up and down, so I would say there is no significant difference.

For comparing runtimes it seems reasonable. But my question was aimed at technical necessity, e.g. would I also be required to do the switching in production runs to ensure correctness of the results? Rereading my question, that was not very clear; sorry for that.

OliverAh avatar Apr 29 '24 12:04 OliverAh