Stochastic Variational GP Regression

Overview

In this notebook, we’ll give an overview of how to use SVGP stochastic variational regression (https://arxiv.org/pdf/1411.2005.pdf) to train rapidly on the elevators UCI dataset using minibatches. This is one of the more common use cases of variational inference for GPs.

If you are unfamiliar with variational inference, we recommend the following resources:

- Variational Inference: A Review for Statisticians by David M. Blei, Alp Kucukelbir, and Jon D. McAuliffe.
- Scalable Variational Gaussian Process Classification by James Hensman, Alex Matthews, and Zoubin Ghahramani.
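
As a quick refresher on what SVGP optimizes (a sketch of the objective from the Hensman et al. paper linked above): the model introduces \(m\) inducing points with a variational distribution \(q(\mathbf{u})\) over their function values, and maximizes the evidence lower bound (ELBO)

\[
\mathcal{L} = \sum_{i=1}^{n} \mathbb{E}_{q(f_i)}\left[\log p(y_i \mid f_i)\right] - \mathrm{KL}\left[q(\mathbf{u}) \,\|\, p(\mathbf{u})\right].
\]

Because the first term is a sum over individual data points, it can be estimated from random minibatches, which is what makes stochastic optimization possible here.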

[1]:
import tqdm.notebook
import math
import torch
import gpytorch
from matplotlib import pyplot as plt

# Make plots inline
%matplotlib inline

For this example notebook, we’ll be using the elevators UCI dataset. Running the next cell downloads a copy of the dataset. We scale each input dimension to \([-1, 1]\), and we’ll simply split the data, using the first 80% as training and the last 20% as testing.

Note: Running the next cell will attempt to download the elevators.mat dataset file to the parent directory (../elevators.mat) if it is not already present.

[2]:
import urllib.request
import os
from scipy.io import loadmat
from math import floor


# this is for running the notebook in our testing framework
smoke_test = ('CI' in os.environ)


if not smoke_test and not os.path.isfile('../elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', '../elevators.mat')


if smoke_test:  # this is for running the notebook in our testing framework
    X, y = torch.randn(1000, 3), torch.randn(1000)
else:
    data = torch.Tensor(loadmat('../elevators.mat')['data'])
    X = data[:, :-1]
    X = X - X.min(0)[0]
    X = 2 * (X / X.max(0)[0]) - 1
    y = data[:, -1]


train_n = int(floor(0.8 * len(X)))
train_x = X[:train_n, :].contiguous()
train_y = y[:train_n].contiguous()

test_x = X[train_n:, :].contiguous()
test_y = y[train_n:].contiguous()

if torch.cuda.is_available():
    train_x, train_y, test_x, test_y = train_x.cuda(), train_y.cuda(), test_x.cuda(), test_y.cuda()

Creating a DataLoader

The next step is to create a torch DataLoader that will handle getting us random minibatches of data. This involves using the standard TensorDataset and DataLoader modules provided by PyTorch.

In this notebook we’ll be using a fairly large batch size of 1024 just to make optimization run faster, but you can of course change this as you see fit.

[3]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)

test_dataset = TensorDataset(test_x, test_y)
test_loader = DataLoader(test_dataset, batch_size=1024, shuffle=False)

Creating a SVGP Model

For most variational/approximate GP models, you will need to construct the following GPyTorch objects:

  1. A GP Model (gpytorch.models.ApproximateGP) - This handles basic variational inference.

  2. A Variational distribution (gpytorch.variational._VariationalDistribution) - This tells us what form the variational distribution q(u) should take.

  3. A Variational strategy (gpytorch.variational._VariationalStrategy) - This tells us how to transform a distribution q(u) over the inducing point values to a distribution q(f) over the latent function values for some input x.

Here, we use a VariationalStrategy with learn_inducing_locations=True, and a CholeskyVariationalDistribution. These are the most straightforward and common options; a sketch of an alternative is shown below.
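
As a minimal illustrative sketch (not used in the rest of this notebook), GPyTorch also provides cheaper variational distributions, such as MeanFieldVariationalDistribution, which can be swapped in for the Cholesky parameterization:

from gpytorch.variational import MeanFieldVariationalDistribution

num_inducing = 500  # same number of inducing points as used below

# A mean-field variational distribution parameterizes q(u) with a diagonal
# covariance, which is cheaper (O(m) instead of O(m^2) parameters) but less
# expressive than the full Cholesky parameterization.
mean_field_dist = MeanFieldVariationalDistribution(num_inducing)

# It plugs into the same VariationalStrategy, in place of the
# CholeskyVariationalDistribution constructed in the model below.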

The GP Model

The ApproximateGP model is GPyTorch’s simplest approximate inference model. It approximates the true posterior with a distribution specified by a VariationalDistribution, which is most commonly some form of MultivariateNormal distribution. The model defines all the variational parameters that are needed, and keeps all of this information under the hood.

The components of a user built ApproximateGP model in GPyTorch are:

  1. An __init__ method that constructs a mean module, a kernel module, a variational distribution object, and a variational strategy object. This method should also be responsible for constructing whatever other modules might be necessary.

  2. A forward method that takes in some \(n \times d\) data x and returns a MultivariateNormal with the prior mean and covariance evaluated at x. In other words, we return the vector \(\mu(x)\) and the \(n \times n\) matrix \(K_{xx}\) representing the prior mean and covariance matrix of the GP.

[4]:
from gpytorch.models import ApproximateGP
from gpytorch.variational import CholeskyVariationalDistribution
from gpytorch.variational import VariationalStrategy

class GPModel(ApproximateGP):
    def __init__(self, inducing_points):
        variational_distribution = CholeskyVariationalDistribution(inducing_points.size(0))
        variational_strategy = VariationalStrategy(self, inducing_points, variational_distribution, learn_inducing_locations=True)
        super(GPModel, self).__init__(variational_strategy)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

inducing_points = train_x[:500, :]
model = GPModel(inducing_points=inducing_points)
likelihood = gpytorch.likelihoods.GaussianLikelihood()

if torch.cuda.is_available():
    model = model.cuda()
    likelihood = likelihood.cuda()

Training the Model

The cell below trains the model above, learning the kernel and likelihood hyperparameters of the Gaussian process together with the variational parameters in an end-to-end fashion, by maximizing the variational ELBO (a form of approximate Type-II MLE).

Unlike the exact GP marginal log likelihood, the variational ELBO decomposes over the training data, so performing variational inference allows us to make use of stochastic optimization techniques. For this example, we’ll train for a few epochs (4 by default below), which should be enough to reach reasonable accuracy on this dataset.

The optimization loop differs from the one in our simpler tutorials in that it loops over both epochs and minibatches of the data. However, the basic process is the same: for each minibatch, we forward through the model, compute the loss (the negative of the VariationalELBO), call backward, and take an optimizer step.
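
One detail worth noting is the num_data argument passed to VariationalELBO below: it tells the objective how many training points there are in total, so that the minibatch likelihood term and the KL term are weighted correctly against each other. Roughly (a sketch, following the ELBO above), a minibatch \(B\) gives the estimate

\[
\mathcal{L} \approx \frac{n}{|B|} \sum_{i \in B} \mathbb{E}_{q(f_i)}\left[\log p(y_i \mid f_i)\right] - \mathrm{KL}\left[q(\mathbf{u}) \,\|\, p(\mathbf{u})\right],
\]

which is, up to a constant scaling, an unbiased estimate of the full ELBO.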

[5]:
num_epochs = 1 if smoke_test else 4


model.train()
likelihood.train()

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
    {'params': likelihood.parameters()},
], lr=0.01)

# Our loss object. We're using the VariationalELBO
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.size(0))


epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter:
    # Within each iteration, we will go over each minibatch of data
    minibatch_iter = tqdm.notebook.tqdm(train_loader, desc="Minibatch", leave=False)
    for x_batch, y_batch in minibatch_iter:
        optimizer.zero_grad()
        output = model(x_batch)
        loss = -mll(output, y_batch)
        minibatch_iter.set_postfix(loss=loss.item())
        loss.backward()
        optimizer.step()

Making Predictions

The next cell makes predictions on the test set. We iterate over the test DataLoader in minibatches and collect the predictive means (preds.mean); the full predictive distribution, including the covariance, is available from the MultivariateNormal returned by the model. Because the test set is much smaller than the training set, the minibatching here is mainly a memory convenience; you could also pass the full test_x tensor through the model in a single call.

[6]:
model.eval()
likelihood.eval()
# Collect the predictive means over the test set in minibatches
means = torch.tensor([0.])
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        preds = model(x_batch)
        means = torch.cat([means, preds.mean.cpu()])
means = means[1:]  # drop the dummy value used to initialize the tensor
[7]:
print('Test MAE: {}'.format(torch.mean(torch.abs(means - test_y.cpu()))))
Test MAE: 0.11148034781217575
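
The model’s output is a full MultivariateNormal, so predictive uncertainties are available as well. The following is a minimal sketch (not run above) of how you might also collect predictive variances, passing the model output through the likelihood to get the predictive distribution over observations rather than over the latent function:

model.eval()
likelihood.eval()

means, variances = [], []
with torch.no_grad():
    for x_batch, y_batch in test_loader:
        # likelihood(model(x)) adds the observation noise to the latent predictions
        preds = likelihood(model(x_batch))
        means.append(preds.mean.cpu())
        variances.append(preds.variance.cpu())

means = torch.cat(means)
variances = torch.cat(variances)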