Scalable Additive-Structure GP Classification with KISS-GP (CUDA)

Introduction

This example shows how to use an AdditiveGridInterpolationVariationalStrategy module. This classification module is designed for functions that decompose additively over their input dimensions. This is equivalent to using a covariance function that additively decomposes over dimensions:

\[k(\mathbf{x},\mathbf{x'}) = \sum_{i=1}^{d}k([\mathbf{x}]_{i}, [\mathbf{x'}]_{i})\]

where \([\mathbf{x}]_{i}\) denotes the \(i\)-th component of the vector \(\mathbf{x}\). Example applications include Bayesian optimization and deep kernel learning.
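For intuition, such an additive covariance can also be written out directly in GPyTorch by summing one-dimensional kernels, each restricted to a single input column via active_dims. The snippet below is only an illustrative sketch and is not used in the rest of this example.

import gpytorch

# Illustrative sketch: an additive kernel over two input dimensions, built by
# summing one-dimensional RBF kernels restricted via `active_dims`
additive_covar = (
    gpytorch.kernels.RBFKernel(active_dims=[0])
    + gpytorch.kernels.RBFKernel(active_dims=[1])
)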

The use of inducing points makes the computational complexity linear, rather than cubic, in the number of data points, which allows the model to scale to much larger training sets.

In this example, we’re performing classification on a two-dimensional toy dataset that is:

- Defined in [-1, 1] x [-1, 1]
- Valued 1 in [-0.5, 0.5] x [-0.5, 0.5]
- Valued -1 otherwise

The above function doesn’t have an obvious additive decomposition, but it turns out that it can be approximated very well by an additive kernel anyway.

[1]:
# High-level imports
import math
from math import exp
import torch
import gpytorch
from matplotlib import pyplot as plt

# Make inline plots
%matplotlib inline

Generate toy dataset

[2]:
n = 51
# Training inputs: a regular 51 x 51 grid over [-1, 1] x [-1, 1]
train_x = torch.zeros(n ** 2, 2)
train_x[:, 0].copy_(torch.linspace(-1, 1, n).repeat(n))
train_x[:, 1].copy_(torch.linspace(-1, 1, n).unsqueeze(1).repeat(1, n).view(-1))
# Training labels: 1 inside [-0.5, 0.5] x [-0.5, 0.5], 0 elsewhere
train_y = (train_x[:, 0].abs().lt(0.5)).float() * (train_x[:, 1].abs().lt(0.5)).float()

train_x = train_x.cuda()
train_y = train_y.cuda()

Define the model

In contrast to the most basic classification models, this model uses an AdditiveGridInterpolationVariationalStrategy. This causes two key changes in the model. First, the model now assumes that the input to forward, x, will be additively decomposed across its dimensions. Thus, although the model below defines an RBFKernel as its covariance function, the variational strategy applies this one-dimensional kernel to each input dimension and sums the results, imposing the additive decomposition discussed above.

Second, this model automatically uses scalable kernel interpolation (SKI) for each dimension. Because of the additive decomposition, we only provide one set of grid bounds to the variational strategy, as the same grid is used for every dimension. It is recommended that you scale your training and test data so that they fall within these grid bounds (a simple sketch of one way to do this follows).
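One simple way to do that scaling is to min-max normalize each input dimension into the grid bounds. The helper below is a hypothetical sketch (scale_to_grid_bounds is our own name, not a GPyTorch function); in this notebook the toy data is already generated inside [-1, 1], so no rescaling is actually needed.

# Hypothetical helper (not part of GPyTorch or this notebook): rescale each
# input dimension into [lower, upper] so it lies inside the SKI grid bounds
def scale_to_grid_bounds(x, lower=-1.0, upper=1.0):
    x_min = x.min(dim=0, keepdim=True).values
    x_max = x.max(dim=0, keepdim=True).values
    x_unit = (x - x_min) / (x_max - x_min)   # per-dimension map to [0, 1]
    return x_unit * (upper - lower) + lower  # then map to [lower, upper]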

[3]:
from gpytorch.models import ApproximateGP
from gpytorch.variational import AdditiveGridInterpolationVariationalStrategy, CholeskyVariationalDistribution
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.likelihoods import BernoulliLikelihood
from gpytorch.means import ConstantMean
from gpytorch.distributions import MultivariateNormal

class GPClassificationModel(ApproximateGP):
    def __init__(self, grid_size=64, grid_bounds=([-1, 1],)):
        # One batch of variational parameters per input dimension (batch_size=2)
        variational_distribution = CholeskyVariationalDistribution(num_inducing_points=grid_size, batch_size=2)
        # A single SKI grid (with the given bounds) is shared across both dimensions
        variational_strategy = AdditiveGridInterpolationVariationalStrategy(self,
                                                                            grid_size=grid_size,
                                                                            grid_bounds=grid_bounds,
                                                                            num_dim=2,
                                                                            variational_distribution=variational_distribution)
        super(GPClassificationModel, self).__init__(variational_strategy)
        # The base kernel is one-dimensional; the variational strategy applies it
        # to each dimension and sums the results (the additive decomposition)
        self.mean_module = ConstantMean()
        self.covar_module = ScaleKernel(RBFKernel(ard_num_dims=1))

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = MultivariateNormal(mean_x, covar_x)
        return latent_pred

# Move the model and likelihood to the GPU
model = GPClassificationModel().cuda()
likelihood = gpytorch.likelihoods.BernoulliLikelihood().cuda()

Training the model

Once the model has been defined, the training loop looks very similar to that of other variational models we’ve seen. We optimize the variational lower bound (ELBO) as our objective function. Although variational inference in GPyTorch supports stochastic (minibatch) gradient descent, we use full-batch optimization here because the toy dataset is relatively small.
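For reference, the objective being maximized is the standard variational evidence lower bound (ELBO) for sparse GP classification,

\[\mathcal{L} = \sum_{i=1}^{N}\mathbb{E}_{q(f_{i})}\left[\log p(y_{i} \mid f_{i})\right] - \mathrm{KL}\left[\,q(\mathbf{u})\,\|\,p(\mathbf{u})\,\right],\]

where \(q(\mathbf{u})\) is the variational distribution over the inducing values and \(p(\mathbf{u})\) is the GP prior at the inducing points. GPyTorch’s VariationalELBO computes this quantity, and we minimize its negative as the loss.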

For an example of using the AdditiveGridInterpolationVariationalStrategy model with stochastic gradient descent, see the dkl_mnist example.
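As a rough illustration only (not part of this notebook), minibatch training of the same model might look like the sketch below. It reuses the model, likelihood, train_x, and train_y defined above; the batch size and number of epochs are arbitrary choices.

from torch.utils.data import TensorDataset, DataLoader

# Sketch of stochastic (minibatch) training -- assumes the `model`, `likelihood`,
# `train_x`, and `train_y` objects defined elsewhere in this notebook
train_loader = DataLoader(TensorDataset(train_x, train_y), batch_size=256, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())

model.train()
likelihood.train()
for epoch in range(10):
    for x_batch, y_batch in train_loader:
        optimizer.zero_grad()
        loss = -mll(model(x_batch), y_batch)
        loss.backward()
        optimizer.step()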

[4]:
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# "Loss" for GPs - the marginal log likelihood
# num_data refers to the number of training data points
mll = gpytorch.mlls.VariationalELBO(likelihood, model, num_data=train_y.numel())

# Training function
def train(num_iter=50):
    for i in range(num_iter):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, num_iter, loss.item()))
        optimizer.step()

%time train()
Iter 1/50 - Loss: 1.115
Iter 2/50 - Loss: 1.108
Iter 3/50 - Loss: 1.873
Iter 4/50 - Loss: 1.382
Iter 5/50 - Loss: 1.366
Iter 6/50 - Loss: 1.286
Iter 7/50 - Loss: 1.021
Iter 8/50 - Loss: 0.837
Iter 9/50 - Loss: 0.809
Iter 10/50 - Loss: 0.801
Iter 11/50 - Loss: 0.706
Iter 12/50 - Loss: 0.611
Iter 13/50 - Loss: 0.575
Iter 14/50 - Loss: 0.550
Iter 15/50 - Loss: 0.513
Iter 16/50 - Loss: 0.474
Iter 17/50 - Loss: 0.450
Iter 18/50 - Loss: 0.426
Iter 19/50 - Loss: 0.395
Iter 20/50 - Loss: 0.369
Iter 21/50 - Loss: 0.351
Iter 22/50 - Loss: 0.336
Iter 23/50 - Loss: 0.318
Iter 24/50 - Loss: 0.301
Iter 25/50 - Loss: 0.282
Iter 26/50 - Loss: 0.264
Iter 27/50 - Loss: 0.250
Iter 28/50 - Loss: 0.238
Iter 29/50 - Loss: 0.231
Iter 30/50 - Loss: 0.217
Iter 31/50 - Loss: 0.205
Iter 32/50 - Loss: 0.194
Iter 33/50 - Loss: 0.189
Iter 34/50 - Loss: 0.180
Iter 35/50 - Loss: 0.173
Iter 36/50 - Loss: 0.166
Iter 37/50 - Loss: 0.160
Iter 38/50 - Loss: 0.155
Iter 39/50 - Loss: 0.149
Iter 40/50 - Loss: 0.141
Iter 41/50 - Loss: 0.138
Iter 42/50 - Loss: 0.141
Iter 43/50 - Loss: 0.131
Iter 44/50 - Loss: 0.135
Iter 45/50 - Loss: 0.122
Iter 46/50 - Loss: 0.119
Iter 47/50 - Loss: 0.123
Iter 48/50 - Loss: 0.116
Iter 49/50 - Loss: 0.107
Iter 50/50 - Loss: 0.108
CPU times: user 5.95 s, sys: 640 ms, total: 6.59 s
Wall time: 7.13 s

Test the model

Next we test the model and plot the decision boundary. Even though the function we are modeling has no obvious additive decomposition, the model recovers it accurately.

[5]:
# Switch the model and likelihood into evaluation mode
model.eval()
likelihood.eval()

# Start the plot, 4x3in
f, ax = plt.subplots(1, 1, figsize=(4, 3))

n = 150
test_x = torch.zeros(n ** 2, 2)
test_x[:, 0].copy_(torch.linspace(-1, 1, n).repeat(n))
test_x[:, 1].copy_(torch.linspace(-1, 1, n).unsqueeze(1).repeat(1, n).view(-1))
# Cuda variable of test data
test_x = test_x.cuda()

with torch.no_grad():
    predictions = likelihood(model(test_x))

# prob<0.5 --> label -1 // prob>0.5 --> label 1
pred_labels = predictions.mean.ge(0.5).float().cpu()
# Colors: yellow for predicted class 1, red for predicted class -1
color = ['y' if label == 1 else 'r' for label in pred_labels]

# Plot the data as a scatter plot
ax.scatter(test_x[:, 0].cpu(), test_x[:, 1].cpu(), color=color, s=1)
[5]:
<matplotlib.collections.PathCollection at 0x7fbfdc12a908>
[Figure: scatter plot of the test grid, colored by predicted class (yellow = 1, red = -1)]
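Since the toy labels are known analytically, we could also check the classification accuracy directly against the ground truth. The snippet below is just a quick sketch (not part of the original notebook); it reuses test_x and pred_labels from above.

# Ground-truth labels on the test grid, using the same rule as the training data
test_y = ((test_x[:, 0].abs().lt(0.5)) & (test_x[:, 1].abs().lt(0.5))).float().cpu()
# Fraction of grid points where the predicted label matches the ground truth
accuracy = pred_labels.eq(test_y).float().mean().item()
print('Test accuracy: %.3f' % accuracy)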