VNNGP with Batches
Discussed in https://github.com/cornellius-gp/gpytorch/discussions/2300
Originally posted by Turakar on March 12, 2023:

I am trying to get a VNNGP to work in a batched setting. To this end, I tried the following code, which is based on the tutorial.
```python
import math

import matplotlib.pyplot as plt
import torch
from torch import Tensor
from tqdm.auto import tqdm

from gpytorch import settings
from gpytorch.distributions import MultivariateNormal
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.means import ConstantMean
from gpytorch.mlls import PredictiveLogLikelihood
from gpytorch.models import ApproximateGP
from gpytorch.variational import MeanFieldVariationalDistribution, NNVariationalStrategy


class BatchGPModel(ApproximateGP):
    def __init__(self, train_x: Tensor):
        batch_shape = train_x.shape[:-2]
        # VNNGP places the inducing points at the training inputs
        inducing_points = torch.clone(train_x)
        variational_distribution = MeanFieldVariationalDistribution(
            inducing_points.size(-2), batch_shape=batch_shape
        )
        variational_strategy = NNVariationalStrategy(
            self, inducing_points, variational_distribution, k=25, training_batch_size=25
        )
        super().__init__(variational_strategy)
        self.mean_module = ConstantMean(batch_shape=batch_shape)
        self.covar_module = ScaleKernel(RBFKernel(batch_shape=batch_shape), batch_shape=batch_shape)
        self.likelihood = GaussianLikelihood(batch_shape=batch_shape)

    def forward(self, x: Tensor) -> MultivariateNormal:
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, prior=False, **kwargs):
        if x is not None:
            if x.dim() == 1:
                x = x.unsqueeze(-1)
        return self.variational_strategy(x=x, prior=prior, **kwargs)


def main():
    x = torch.linspace(0, 1, 100)
    train_y = torch.stack(
        [
            torch.sin(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
            torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
            torch.sin(x * (2 * math.pi)) + 2 * torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
            -torch.cos(x * (2 * math.pi)) + torch.randn(x.size()) * 0.2,
        ],
        0,
    )
    train_x = torch.stack([x] * 4).unsqueeze(-1)
    num_tasks = 4

    # Initialize model
    model = BatchGPModel(train_x)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
    mll = PredictiveLogLikelihood(model.likelihood, model, num_data=x.size(0))

    # With x=None, the strategy draws its own minibatch of training points and
    # exposes the chosen indices via current_training_indices.
    num_batches = model.variational_strategy._total_training_batches
    epochs_iter = tqdm(range(50), desc="Epoch")
    for _ in epochs_iter:
        minibatch_iter = tqdm(range(num_batches), desc="Minibatch", leave=False)
        for _ in minibatch_iter:
            optimizer.zero_grad()
            output = model(x=None)
            current_training_indices = model.variational_strategy.current_training_indices
            y_batch = train_y[:, current_training_indices]
            loss = -mll(output, y_batch).sum()
            minibatch_iter.set_postfix(loss=loss.item())
            loss.backward()
            optimizer.step()

    # Get into evaluation (predictive posterior) mode
    model.eval()

    # Initialize plots
    fig, axs = plt.subplots(1, num_tasks, figsize=(4 * num_tasks, 3))

    # Make predictions
    with torch.no_grad(), settings.fast_pred_var():
        test_x = torch.stack([torch.linspace(0, 1, 51)] * num_tasks).unsqueeze(-1)
        predictions = model.likelihood(model(test_x))
        mean = predictions.mean
        lower, upper = predictions.confidence_region()

    for task, ax in enumerate(axs):
        # Plot training data as black stars
        ax.plot(x.detach().numpy(), train_y[task].detach().numpy(), "k*")
        # Predictive mean as blue line
        ax.plot(test_x[0, :, 0].numpy(), mean[task].numpy(), "b")
        # Shade in confidence region
        ax.fill_between(test_x[0, :, 0].numpy(), lower[task].numpy(), upper[task].numpy(), alpha=0.5)
        ax.set_ylim([-3, 3])
        ax.legend(["Observed Data", "Mean", "Confidence"])
        ax.set_title(f"Task {task + 1}")
    fig.tight_layout()
    plt.show()


if __name__ == "__main__":
    main()
```
However, this does not work: `_stochastic_kl_helper()` in `NNVariationalStrategy` calls `forward()` of the `covar_module` with an input of shape `(4, 25, 25, 1)`, while a shape of `(4, n, 1)` is expected. Did I get something wrong about the usage of `NNVariationalStrategy`, or is there something wrong in GPyTorch? This is the exact stack trace:
```
Traceback (most recent call last):
  File "/path/to/gpytorch/snippet_variational_batch_sogp.py", line 109, in <module>
    main()
  File "/path/to/gpytorch/snippet_variational_batch_sogp.py", line 71, in main
    output = model(x=None)
  File "/path/to/gpytorch/snippet_variational_batch_sogp.py", line 40, in __call__
    return self.variational_strategy(x=x, prior=False, **kwargs)
  File "/path/to/gpytorch/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 131, in __call__
    return self.forward(x, self.inducing_points, None, None)
  File "/path/to/gpytorch/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 168, in forward
    kl = self._kl_divergence(kl_indices)
  File "/path/to/gpytorch/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 325, in _kl_divergence
    kl = self._stochastic_kl_helper(kl_indices) * self.M / len(kl_indices)
  File "/path/to/gpytorch/gpytorch/variational/nearest_neighbor_variational_strategy.py", line 273, in _stochastic_kl_helper
    cov = self.model.covar_module.forward(nearest_neighbors, nearest_neighbors)
  File "/path/to/gpytorch/gpytorch/kernels/scale_kernel.py", line 109, in forward
    orig_output = self.base_kernel.forward(x1, x2, diag=diag, last_dim_is_batch=last_dim_is_batch, **params)
  File "/path/to/gpytorch/gpytorch/kernels/rbf_kernel.py", line 80, in forward
    return RBFCovariance.apply(
  File "/path/to/gpytorch/gpytorch/functions/rbf_covariance.py", line 12, in forward
    x1_ = x1.div(lengthscale)
RuntimeError: The size of tensor a (25) must match the size of tensor b (4) at non-singleton dimension 1
```
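For reference, the mismatch can be reproduced in isolation. The sketch below is illustrative only, assuming just the shapes reported in the trace: a kernel with `batch_shape=(4,)` carries a lengthscale of shape `(4, 1, 1)`, which cannot broadcast against the extra nearest-neighbor dimension of a `(4, 25, 25, 1)` input.

```python
import torch
from gpytorch.kernels import RBFKernel, ScaleKernel

# Minimal sketch of the failing call, assuming the shapes from the trace above:
# _stochastic_kl_helper passes the (batch, n, k, d) nearest-neighbor tensor
# directly to covar_module.forward, but a kernel built with batch_shape (4,)
# expects (batch, n, d) inputs.
kernel = ScaleKernel(RBFKernel(batch_shape=torch.Size([4])), batch_shape=torch.Size([4]))
nearest_neighbors = torch.randn(4, 25, 25, 1)  # (batch=4, n=25, k=25, d=1)
kernel.forward(nearest_neighbors, nearest_neighbors)
# RuntimeError: The size of tensor a (25) must match the size of tensor b (4)
# at non-singleton dimension 1
```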
Does anybody have an idea what's going on here? Maybe @LuhuanWu or @gpleiss?
Submitted PR #2375 and replied in #2300.