Scalable GP Classification in 1D (w/ KISS-GP)

This example shows how to use the GridInducingVariationalGP module. This classification module is designed for cases where the inputs of the function you’re modeling are one-dimensional.

The use of grid inducing points allows scaling to much larger training sets: the computational complexity grows linearly in the number of training points rather than cubically.
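The key idea (a brief recap of the KISS-GP paper linked below, not something computed explicitly in this notebook) is that the train/train covariance matrix is approximated by interpolating against a covariance matrix evaluated on a regular grid of inducing points,

$$K_{XX} \approx W K_{UU} W^\top,$$

where $U$ is the set of grid inducing points and $W$ is a sparse local interpolation matrix. In one dimension, matrix-vector multiplies with this approximation cost roughly $O(n + g \log g)$ for $n$ training points and $g$ grid points, which is where the linear scaling comes from.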

In this example, we’re modeling a function with periodic labels: each label is the sign of cos(2πx), i.e. a square wave on [0, 1] that is +1 near the ends of the interval and −1 in the middle, changing sign at x = 0.25 and x = 0.75.

This notebook doesn’t use CUDA. In general, though, we recommend using a GPU when one is available; most of our other notebooks do use CUDA.
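If you do want to use a GPU, the standard PyTorch pattern applies once the data, model, and likelihood defined below have been created (a minimal sketch, not executed in this notebook):

if torch.cuda.is_available():
    # Move the training data and the model/likelihood parameters to the GPU
    train_x, train_y = train_x.cuda(), train_y.cuda()
    model, likelihood = model.cuda(), likelihood.cuda()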

Kernel interpolation for scalable structured Gaussian processes (KISS-GP) was introduced in this paper: http://proceedings.mlr.press/v37/wilson15.pdf

KISS-GP with SVI for classification was introduced in this paper: https://papers.nips.cc/paper/6426-stochastic-variational-deep-kernel-learning.pdf

In [1]:
import math
import torch
import gpytorch
from matplotlib import pyplot as plt
from math import exp

%matplotlib inline
%load_ext autoreload
%autoreload 2
In [2]:
# 26 evenly spaced training inputs in [0, 1]
train_x = torch.linspace(0, 1, 26)
# Labels are the sign of cos(2*pi*x): +1 on [0, 0.25) and (0.75, 1], -1 in between
train_y = torch.sign(torch.cos(train_x * (2 * math.pi)))
In [3]:
class GPClassificationModel(gpytorch.models.GridInducingVariationalGP):
    def __init__(self):
        # Place 32 inducing points on a regular grid covering [0, 1]
        super(GPClassificationModel, self).__init__(grid_size=32, grid_bounds=[(0, 1)])
        self.mean_module = gpytorch.means.ConstantMean()
        # RBF kernel with a smoothed-box prior on the log lengthscale, wrapped in a ScaleKernel
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(
                log_lengthscale_prior=gpytorch.priors.SmoothedBoxPrior(
                    exp(0), exp(3), sigma=0.1, log_transform=True
                )
            )
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)
        return latent_pred


model = GPClassificationModel()
likelihood = gpytorch.likelihoods.BernoulliLikelihood()
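An aside on the likelihood (not from the original notebook): gpytorch's BernoulliLikelihood uses a probit link, so class probabilities are obtained by squashing the latent GP through the standard normal CDF,

$$p(y = 1 \mid f(x)) = \Phi\big(f(x)\big).$$

This probability is what the prediction cell at the end of the notebook thresholds at 0.5.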
In [4]:
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the SGD optimizer
optimizer = torch.optim.SGD([
    {'params': model.parameters()},
    # BernoulliLikelihood has no parameters
], lr=0.1)

# "Loss" for GPs - the (variational) marginal log likelihood
# num_data refers to the number of training data points
mll = gpytorch.mlls.VariationalMarginalLogLikelihood(likelihood, model, num_data=len(train_y))

def train():
    num_iter = 200
    for i in range(num_iter):
        optimizer.zero_grad()
        output = model(train_x)
        loss = -mll(output, train_y)
        loss.backward()
        print('Iter %d/%d - Loss: %.3f' % (i + 1, num_iter, loss.item()))
        optimizer.step()

# Get clock time
%time train()
Iter 1/200 - Loss: 6097.150
Iter 2/200 - Loss: 392.698
Iter 3/200 - Loss: 352.528
Iter 4/200 - Loss: 322.668
Iter 5/200 - Loss: 324.150
Iter 6/200 - Loss: 362.064
Iter 7/200 - Loss: 416.016
Iter 8/200 - Loss: 298.819
Iter 9/200 - Loss: 252.523
Iter 10/200 - Loss: 211.480
Iter 11/200 - Loss: 321.593
Iter 12/200 - Loss: 289.411
Iter 13/200 - Loss: 347.734
Iter 14/200 - Loss: 330.747
Iter 15/200 - Loss: 296.441
Iter 16/200 - Loss: 320.904
Iter 17/200 - Loss: 309.613
Iter 18/200 - Loss: 281.426
Iter 19/200 - Loss: 296.923
Iter 20/200 - Loss: 301.174
Iter 21/200 - Loss: 250.977
Iter 22/200 - Loss: 343.925
Iter 23/200 - Loss: 240.633
Iter 24/200 - Loss: 281.263
Iter 25/200 - Loss: 271.904
Iter 26/200 - Loss: 325.001
Iter 27/200 - Loss: 272.500
Iter 28/200 - Loss: 271.221
Iter 29/200 - Loss: 284.778
Iter 30/200 - Loss: 243.273
Iter 31/200 - Loss: 292.012
Iter 32/200 - Loss: 302.940
Iter 33/200 - Loss: 290.159
Iter 34/200 - Loss: 320.918
Iter 35/200 - Loss: 259.498
Iter 36/200 - Loss: 277.864
Iter 37/200 - Loss: 293.995
Iter 38/200 - Loss: 223.749
Iter 39/200 - Loss: 265.449
Iter 40/200 - Loss: 265.746
Iter 41/200 - Loss: 250.393
Iter 42/200 - Loss: 261.780
Iter 43/200 - Loss: 239.201
Iter 44/200 - Loss: 317.880
Iter 45/200 - Loss: 326.695
Iter 46/200 - Loss: 282.346
Iter 47/200 - Loss: 268.213
Iter 48/200 - Loss: 278.097
Iter 49/200 - Loss: 246.392
Iter 50/200 - Loss: 321.273
Iter 51/200 - Loss: 274.065
Iter 52/200 - Loss: 263.765
Iter 53/200 - Loss: 219.523
Iter 54/200 - Loss: 303.933
Iter 55/200 - Loss: 256.324
Iter 56/200 - Loss: 203.140
Iter 57/200 - Loss: 289.328
Iter 58/200 - Loss: 261.303
Iter 59/200 - Loss: 225.208
Iter 60/200 - Loss: 222.645
Iter 61/200 - Loss: 234.964
Iter 62/200 - Loss: 292.547
Iter 63/200 - Loss: 233.788
Iter 64/200 - Loss: 231.033
Iter 65/200 - Loss: 194.131
Iter 66/200 - Loss: 230.455
Iter 67/200 - Loss: 252.459
Iter 68/200 - Loss: 227.107
Iter 69/200 - Loss: 252.148
Iter 70/200 - Loss: 229.926
Iter 71/200 - Loss: 244.014
Iter 72/200 - Loss: 210.347
Iter 73/200 - Loss: 264.777
Iter 74/200 - Loss: 235.150
Iter 75/200 - Loss: 239.858
Iter 76/200 - Loss: 205.147
Iter 77/200 - Loss: 199.181
Iter 78/200 - Loss: 235.487
Iter 79/200 - Loss: 250.423
Iter 80/200 - Loss: 211.550
Iter 81/200 - Loss: 211.175
Iter 82/200 - Loss: 213.312
Iter 83/200 - Loss: 197.529
Iter 84/200 - Loss: 249.012
Iter 85/200 - Loss: 241.818
Iter 86/200 - Loss: 226.489
Iter 87/200 - Loss: 251.521
Iter 88/200 - Loss: 203.768
Iter 89/200 - Loss: 220.160
Iter 90/200 - Loss: 243.473
Iter 91/200 - Loss: 214.500
Iter 92/200 - Loss: 213.951
Iter 93/200 - Loss: 245.208
Iter 94/200 - Loss: 201.523
Iter 95/200 - Loss: 199.266
Iter 96/200 - Loss: 214.818
Iter 97/200 - Loss: 228.327
Iter 98/200 - Loss: 243.201
Iter 99/200 - Loss: 193.552
Iter 100/200 - Loss: 226.596
Iter 101/200 - Loss: 207.586
Iter 102/200 - Loss: 229.452
Iter 103/200 - Loss: 211.403
Iter 104/200 - Loss: 194.898
Iter 105/200 - Loss: 192.584
Iter 106/200 - Loss: 218.825
Iter 107/200 - Loss: 197.878
Iter 108/200 - Loss: 201.669
Iter 109/200 - Loss: 246.887
Iter 110/200 - Loss: 232.580
Iter 111/200 - Loss: 208.174
Iter 112/200 - Loss: 217.168
Iter 113/200 - Loss: 195.321
Iter 114/200 - Loss: 246.281
Iter 115/200 - Loss: 249.421
Iter 116/200 - Loss: 200.820
Iter 117/200 - Loss: 191.208
Iter 118/200 - Loss: 227.009
Iter 119/200 - Loss: 264.285
Iter 120/200 - Loss: 200.157
Iter 121/200 - Loss: 209.431
Iter 122/200 - Loss: 190.169
Iter 123/200 - Loss: 223.926
Iter 124/200 - Loss: 231.914
Iter 125/200 - Loss: 196.829
Iter 126/200 - Loss: 176.027
Iter 127/200 - Loss: 197.739
Iter 128/200 - Loss: 163.040
Iter 129/200 - Loss: 221.040
Iter 130/200 - Loss: 209.215
Iter 131/200 - Loss: 169.048
Iter 132/200 - Loss: 134.395
Iter 133/200 - Loss: 194.889
Iter 134/200 - Loss: 239.895
Iter 135/200 - Loss: 207.784
Iter 136/200 - Loss: 224.677
Iter 137/200 - Loss: 185.859
Iter 138/200 - Loss: 194.485
Iter 139/200 - Loss: 198.281
Iter 140/200 - Loss: 177.267
Iter 141/200 - Loss: 177.465
Iter 142/200 - Loss: 196.033
Iter 143/200 - Loss: 143.547
Iter 144/200 - Loss: 187.040
Iter 145/200 - Loss: 203.045
Iter 146/200 - Loss: 192.756
Iter 147/200 - Loss: 180.532
Iter 148/200 - Loss: 175.648
Iter 149/200 - Loss: 191.526
Iter 150/200 - Loss: 166.489
Iter 151/200 - Loss: 220.140
Iter 152/200 - Loss: 167.087
Iter 153/200 - Loss: 148.467
Iter 154/200 - Loss: 220.460
Iter 155/200 - Loss: 160.580
Iter 156/200 - Loss: 196.464
Iter 157/200 - Loss: 185.087
Iter 158/200 - Loss: 148.367
Iter 159/200 - Loss: 158.299
Iter 160/200 - Loss: 187.548
Iter 161/200 - Loss: 181.689
Iter 162/200 - Loss: 172.187
Iter 163/200 - Loss: 191.411
Iter 164/200 - Loss: 167.754
Iter 165/200 - Loss: 138.704
Iter 166/200 - Loss: 162.195
Iter 167/200 - Loss: 186.930
Iter 168/200 - Loss: 182.635
Iter 169/200 - Loss: 158.236
Iter 170/200 - Loss: 160.126
Iter 171/200 - Loss: 180.415
Iter 172/200 - Loss: 187.367
Iter 173/200 - Loss: 163.659
Iter 174/200 - Loss: 184.058
Iter 175/200 - Loss: 216.402
Iter 176/200 - Loss: 169.361
Iter 177/200 - Loss: 183.626
Iter 178/200 - Loss: 174.367
Iter 179/200 - Loss: 157.275
Iter 180/200 - Loss: 171.675
Iter 181/200 - Loss: 192.713
Iter 182/200 - Loss: 158.222
Iter 183/200 - Loss: 173.345
Iter 184/200 - Loss: 150.134
Iter 185/200 - Loss: 189.955
Iter 186/200 - Loss: 170.120
Iter 187/200 - Loss: 200.875
Iter 188/200 - Loss: 140.360
Iter 189/200 - Loss: 136.488
Iter 190/200 - Loss: 201.296
Iter 191/200 - Loss: 163.410
Iter 192/200 - Loss: 174.225
Iter 193/200 - Loss: 218.408
Iter 194/200 - Loss: 178.131
Iter 195/200 - Loss: 162.437
Iter 196/200 - Loss: 145.230
Iter 197/200 - Loss: 151.984
Iter 198/200 - Loss: 121.274
Iter 199/200 - Loss: 137.730
Iter 200/200 - Loss: 145.839
CPU times: user 18.3 s, sys: 19.7 s, total: 38 s
Wall time: 5.44 s
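As an aside (not part of the timed run above), plain SGD produces the fairly noisy loss trace seen here. A common variation is to switch to the Adam optimizer, which often gives a smoother descent on this kind of objective; a minimal sketch, using only standard torch.optim:

# Hypothetical alternative to the SGD optimizer used above
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)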
In [5]:
# Set model and likelihood into eval mode
model.eval()
likelihood.eval()

# Initialize axes
f, ax = plt.subplots(1, 1, figsize=(4, 3))

with torch.no_grad():
    test_x = torch.linspace(0, 1, 101)
    # Bernoulli predictive distribution over the test points
    predictions = likelihood(model(test_x))

# Plot the training data as black stars
ax.plot(train_x.numpy(), train_y.numpy(), 'k*')
# Threshold the predictive probabilities at 0.5 and map {0, 1} -> {-1, +1}
pred_labels = predictions.mean.ge(0.5).float().mul(2).sub(1)
ax.plot(test_x.numpy(), pred_labels.numpy(), 'b')
ax.set_ylim([-3, 3])
ax.legend(['Observed Data', 'Predicted Labels'])
Out[5]:
<matplotlib.legend.Legend at 0x7f0bec141080>
[Figure: training data (black stars) and predicted labels in {−1, +1} (blue) over the test interval [0, 1]]
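If you'd rather look at soft probabilities than hard labels, the same predictive object can be plotted directly. This is a small variation on the cell above (assuming, as the thresholding code already does, that predictions.mean holds P(y = +1) at each test point):

with torch.no_grad():
    probs = likelihood(model(test_x)).mean  # predictive P(y = +1) on the test grid

f2, ax2 = plt.subplots(1, 1, figsize=(4, 3))
ax2.plot(test_x.numpy(), probs.numpy(), 'b')
ax2.axhline(0.5, color='gray', linestyle='--')  # the 0.5 decision threshold used above
ax2.set_ylim([-0.1, 1.1])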