GPyTorch regression with derivative information in 2d

Introduction

In this notebook, we show how to train a GP regression model in GPyTorch on a 2-dimensional function given both function values and derivative observations. We consider modeling the Franke function, where the values and derivatives are contaminated with independent Gaussian noise with standard deviation 0.05.
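
Concretely, each training target is the 3-vector of the function value and its two partial derivatives, observed with additive Gaussian noise:

\[ \mathbf{y}(x) = \Bigl[f(x),\; \frac{\partial f}{\partial x_1}(x),\; \frac{\partial f}{\partial x_2}(x)\Bigr] + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, \sigma^2 I), \]

where \(\sigma = 0.05\) in the code below.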

[1]:
import torch
import gpytorch
import math
from matplotlib import cm
from matplotlib import pyplot as plt
import numpy as np

%matplotlib inline
%load_ext autoreload
%autoreload 2

Franke function

The following is a vectorized implementation of the 2-dimensional Franke function (https://www.sfu.ca/~ssurjano/franke2d.html).
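
For reference, the function implemented below is

\[ f(x, y) = \frac{3}{4} e^{-\frac{(9x-2)^2 + (9y-2)^2}{4}} + \frac{3}{4} e^{-\frac{(9x+1)^2}{49} - \frac{9y+1}{10}} + \frac{1}{2} e^{-\frac{(9x-7)^2 + (9y-3)^2}{4}} - \frac{1}{5} e^{-(9x-4)^2 - (9y-7)^2}, \]

and the code also returns the analytic partial derivatives \(\partial f/\partial x\) and \(\partial f/\partial y\).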

[2]:
def franke(X, Y):
    term1 = .75*torch.exp(-((9*X - 2).pow(2) + (9*Y - 2).pow(2))/4)
    term2 = .75*torch.exp(-((9*X + 1).pow(2))/49 - (9*Y + 1)/10)
    term3 = .5*torch.exp(-((9*X - 7).pow(2) + (9*Y - 3).pow(2))/4)
    term4 = .2*torch.exp(-(9*X - 4).pow(2) - (9*Y - 7).pow(2))

    f = term1 + term2 + term3 - term4
    dfx = -2*(9*X - 2)*9/4 * term1 - 2*(9*X + 1)*9/49 * term2 + \
          -2*(9*X - 7)*9/4 * term3 + 2*(9*X - 4)*9 * term4
    dfy = -2*(9*Y - 2)*9/4 * term1 - 9/10 * term2 + \
          -2*(9*Y - 3)*9/4 * term3 + 2*(9*Y - 7)*9 * term4

    return f, dfx, dfy
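
As an optional sanity check (a sketch, not part of the original notebook), the hand-coded derivatives can be compared against autograd at a few random points:

[ ]:
# Sanity-check sketch: compare the analytic partial derivatives against autograd.
Xc = torch.rand(5, requires_grad=True)
Yc = torch.rand(5, requires_grad=True)
fc, dfxc, dfyc = franke(Xc, Yc)
gx, gy = torch.autograd.grad(fc.sum(), [Xc, Yc])
# Should print True True if the analytic derivatives are consistent with autograd.
print(torch.allclose(gx, dfxc, atol=1e-4), torch.allclose(gy, dfyc, atol=1e-4))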

Setting up the training data

We use a grid of 100 points in \([0,1] \times [0,1]\), with 10 uniformly spaced points per dimension.

[3]:
xv, yv = torch.meshgrid([torch.linspace(0, 1, 10), torch.linspace(0, 1, 10)])
train_x = torch.cat((
    xv.contiguous().view(xv.numel(), 1),
    yv.contiguous().view(yv.numel(), 1)),
    dim=1
)

f, dfx, dfy = franke(train_x[:, 0], train_x[:, 1])
train_y = torch.stack([f, dfx, dfy], -1).squeeze(1)

train_y += 0.05 * torch.randn(train_y.size()) # Add noise to both values and gradients
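
As a quick check (a sketch, not in the original notebook), train_x has one row per grid point and train_y has one column per task:

[ ]:
# Shape-check sketch: 100 grid points, 2 input dimensions, 3 targets per point.
print(train_x.shape, train_y.shape)  # torch.Size([100, 2]) torch.Size([100, 3])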

Setting up the model

A GP prior on the function values implies a multi-output GP prior on the function values and the partial derivatives; see Section 9.4 of http://www.gaussianprocess.org/gpml/chapters/RW9.pdf for more details. This allows using a MultitaskMultivariateNormal and a MultitaskGaussianLikelihood to train a GP model from both function values and gradients. The polynomial kernel extended to model the covariance between values and partial derivatives is implemented in PolynomialKernelGrad, and the corresponding extension of a constant mean is implemented in ConstantMeanGrad.
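
In particular, if \(k(x, x')\) is the covariance function of the values, then (following Section 9.4 of the reference above)

\[ \mathrm{cov}\left(f(x), \frac{\partial f(x')}{\partial x'_j}\right) = \frac{\partial k(x, x')}{\partial x'_j}, \qquad \mathrm{cov}\left(\frac{\partial f(x)}{\partial x_i}, \frac{\partial f(x')}{\partial x'_j}\right) = \frac{\partial^2 k(x, x')}{\partial x_i \, \partial x'_j}, \]

so each pair of inputs contributes a \(3 \times 3\) block (value, x-derivative, y-derivative) to the joint covariance matrix.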

[4]:
class GPModelWithDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPModelWithDerivatives, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        self.base_kernel = gpytorch.kernels.PolynomialKernelGrad(power=2)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=3)  # Value + x-derivative + y-derivative
model = GPModelWithDerivatives(train_x, train_y, likelihood)
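
As a quick check (a sketch, not part of the original notebook), evaluating the model at the training inputs returns a MultitaskMultivariateNormal whose mean has one column per task:

[ ]:
# Sketch: the (prior) output at the training inputs is a MultitaskMultivariateNormal
# with mean of shape (100, 3): value, x-derivative, y-derivative.
with torch.no_grad():
    prior = model(train_x)
print(type(prior).__name__, tuple(prior.mean.shape))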

Training the model

Training the model is similar to training a standard GP regression model.

[5]:
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam([
    {'params': model.parameters()},  # Includes MultitaskGaussianLikelihood parameters
], lr=0.05)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

n_iter = 100
for i in range(n_iter):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print("Iter %d/%d - Loss: %.3f   offset: %.3f   noise: %.3f" % (
        i + 1, n_iter, loss.item(),
        model.covar_module.base_kernel.offset.item(),
        model.likelihood.noise.item()
    ))
    optimizer.step()
Iter 1/100 - Loss: 133.388   offset: 0.693   noise: 0.693
Iter 2/100 - Loss: 133.050   offset: 0.668   noise: 0.669
Iter 3/100 - Loss: 132.400   offset: 0.645   noise: 0.645
Iter 4/100 - Loss: 130.977   offset: 0.622   noise: 0.621
Iter 5/100 - Loss: 129.276   offset: 0.600   noise: 0.598
Iter 6/100 - Loss: 128.238   offset: 0.578   noise: 0.576
Iter 7/100 - Loss: 128.037   offset: 0.558   noise: 0.555
Iter 8/100 - Loss: 127.009   offset: 0.537   noise: 0.534
Iter 9/100 - Loss: 125.708   offset: 0.517   noise: 0.514
Iter 10/100 - Loss: 125.870   offset: 0.498   noise: 0.495
Iter 11/100 - Loss: 124.483   offset: 0.479   noise: 0.476
Iter 12/100 - Loss: 123.123   offset: 0.460   noise: 0.458
Iter 13/100 - Loss: 123.445   offset: 0.442   noise: 0.441
Iter 14/100 - Loss: 121.856   offset: 0.425   noise: 0.424
Iter 15/100 - Loss: 122.067   offset: 0.408   noise: 0.409
Iter 16/100 - Loss: 121.626   offset: 0.392   noise: 0.394
Iter 17/100 - Loss: 120.576   offset: 0.376   noise: 0.379
Iter 18/100 - Loss: 120.080   offset: 0.360   noise: 0.366
Iter 19/100 - Loss: 119.255   offset: 0.345   noise: 0.352
Iter 20/100 - Loss: 117.712   offset: 0.330   noise: 0.340
Iter 21/100 - Loss: 117.911   offset: 0.316   noise: 0.328
Iter 22/100 - Loss: 118.118   offset: 0.303   noise: 0.317
Iter 23/100 - Loss: 117.196   offset: 0.290   noise: 0.306
Iter 24/100 - Loss: 115.636   offset: 0.278   noise: 0.296
Iter 25/100 - Loss: 115.403   offset: 0.265   noise: 0.285
Iter 26/100 - Loss: 114.443   offset: 0.254   noise: 0.276
Iter 27/100 - Loss: 114.576   offset: 0.244   noise: 0.266
Iter 28/100 - Loss: 113.681   offset: 0.233   noise: 0.257
Iter 29/100 - Loss: 112.765   offset: 0.223   noise: 0.248
Iter 30/100 - Loss: 112.422   offset: 0.214   noise: 0.240
Iter 31/100 - Loss: 110.949   offset: 0.205   noise: 0.231
Iter 32/100 - Loss: 110.494   offset: 0.196   noise: 0.223
Iter 33/100 - Loss: 110.101   offset: 0.189   noise: 0.214
Iter 34/100 - Loss: 109.204   offset: 0.181   noise: 0.206
Iter 35/100 - Loss: 108.897   offset: 0.174   noise: 0.198
Iter 36/100 - Loss: 107.919   offset: 0.167   noise: 0.191
Iter 37/100 - Loss: 107.100   offset: 0.161   noise: 0.183
Iter 38/100 - Loss: 106.777   offset: 0.155   noise: 0.175
Iter 39/100 - Loss: 106.262   offset: 0.149   noise: 0.168
Iter 40/100 - Loss: 105.564   offset: 0.144   noise: 0.161
Iter 41/100 - Loss: 106.274   offset: 0.140   noise: 0.154
Iter 42/100 - Loss: 103.915   offset: 0.134   noise: 0.147
Iter 43/100 - Loss: 102.595   offset: 0.130   noise: 0.141
Iter 44/100 - Loss: 102.076   offset: 0.126   noise: 0.134
Iter 45/100 - Loss: 101.620   offset: 0.121   noise: 0.128
Iter 46/100 - Loss: 101.127   offset: 0.118   noise: 0.122
Iter 47/100 - Loss: 99.758   offset: 0.114   noise: 0.116
Iter 48/100 - Loss: 99.445   offset: 0.110   noise: 0.111
Iter 49/100 - Loss: 99.830   offset: 0.107   noise: 0.106
Iter 50/100 - Loss: 99.104   offset: 0.103   noise: 0.101
Iter 51/100 - Loss: 97.977   offset: 0.100   noise: 0.096
Iter 52/100 - Loss: 97.119   offset: 0.096   noise: 0.091
Iter 53/100 - Loss: 96.406   offset: 0.093   noise: 0.087
Iter 54/100 - Loss: 96.264   offset: 0.091   noise: 0.082
Iter 55/100 - Loss: 94.711   offset: 0.088   noise: 0.078
Iter 56/100 - Loss: 94.982   offset: 0.086   noise: 0.075
Iter 57/100 - Loss: 93.965   offset: 0.084   noise: 0.071
Iter 58/100 - Loss: 92.833   offset: 0.082   noise: 0.068
Iter 59/100 - Loss: 92.222   offset: 0.081   noise: 0.064
Iter 60/100 - Loss: 92.347   offset: 0.080   noise: 0.061
Iter 61/100 - Loss: 91.647   offset: 0.078   noise: 0.058
Iter 62/100 - Loss: 90.827   offset: 0.077   noise: 0.056
Iter 63/100 - Loss: 89.901   offset: 0.076   noise: 0.053
Iter 64/100 - Loss: 89.392   offset: 0.075   noise: 0.050
Iter 65/100 - Loss: 88.490   offset: 0.075   noise: 0.048
Iter 66/100 - Loss: 88.222   offset: 0.074   noise: 0.046
Iter 67/100 - Loss: 88.573   offset: 0.073   noise: 0.044
Iter 68/100 - Loss: 87.199   offset: 0.073   noise: 0.042
Iter 69/100 - Loss: 87.355   offset: 0.073   noise: 0.040
Iter 70/100 - Loss: 86.861   offset: 0.072   noise: 0.038
Iter 71/100 - Loss: 85.874   offset: 0.072   noise: 0.036
Iter 72/100 - Loss: 85.730   offset: 0.072   noise: 0.035
Iter 73/100 - Loss: 84.800   offset: 0.072   noise: 0.033
Iter 74/100 - Loss: 84.612   offset: 0.072   noise: 0.032
Iter 75/100 - Loss: 84.215   offset: 0.072   noise: 0.031
Iter 76/100 - Loss: 84.063   offset: 0.072   noise: 0.029
Iter 77/100 - Loss: 84.007   offset: 0.073   noise: 0.028
Iter 78/100 - Loss: 83.099   offset: 0.074   noise: 0.027
Iter 79/100 - Loss: 83.234   offset: 0.075   noise: 0.026
Iter 80/100 - Loss: 82.187   offset: 0.076   noise: 0.025
Iter 81/100 - Loss: 82.020   offset: 0.077   noise: 0.024
Iter 82/100 - Loss: 82.789   offset: 0.077   noise: 0.023
Iter 83/100 - Loss: 82.131   offset: 0.078   noise: 0.022
Iter 84/100 - Loss: 81.481   offset: 0.079   noise: 0.021
Iter 85/100 - Loss: 81.546   offset: 0.079   noise: 0.021
Iter 86/100 - Loss: 80.771   offset: 0.080   noise: 0.020
Iter 87/100 - Loss: 80.939   offset: 0.080   noise: 0.019
Iter 88/100 - Loss: 80.208   offset: 0.081   noise: 0.019
Iter 89/100 - Loss: 81.250   offset: 0.082   noise: 0.018
Iter 90/100 - Loss: 80.921   offset: 0.083   noise: 0.018
Iter 91/100 - Loss: 81.127   offset: 0.084   noise: 0.017
Iter 92/100 - Loss: 79.736   offset: 0.085   noise: 0.017
Iter 93/100 - Loss: 81.718   offset: 0.085   noise: 0.016
Iter 94/100 - Loss: 80.228   offset: 0.086   noise: 0.016
Iter 95/100 - Loss: 79.738   offset: 0.086   noise: 0.015
Iter 96/100 - Loss: 80.034   offset: 0.087   noise: 0.015
Iter 97/100 - Loss: 80.193   offset: 0.089   noise: 0.015
Iter 98/100 - Loss: 80.248   offset: 0.090   noise: 0.014
Iter 99/100 - Loss: 79.778   offset: 0.090   noise: 0.014
Iter 100/100 - Loss: 79.247   offset: 0.091   noise: 0.014

Making predictions with the model

Model predictions are similar to GP regression with only function values; here we disable GPyTorch's fast covariance approximations (the fast_computations setting in the cell below) so that exact computations are used when estimating the predictive mean and variance.

[6]:
# Set into eval mode
model.eval()
likelihood.eval()

# Initialize plots
fig, ax = plt.subplots(2, 3, figsize=(14, 10))

# Test points
n1, n2 = 50, 50
xv, yv = torch.meshgrid([torch.linspace(0, 1, n1), torch.linspace(0, 1, n2)])
f, dfx, dfy = franke(xv, yv)

# Make predictions
with torch.no_grad(), gpytorch.settings.fast_computations(log_prob=False, covar_root_decomposition=False):
    test_x = torch.stack([xv.reshape(n1*n2, 1), yv.reshape(n1*n2, 1)], -1).squeeze(1)
    predictions = likelihood(model(test_x))
    mean = predictions.mean

extent = (xv.min(), xv.max(), yv.max(), yv.min())
ax[0, 0].imshow(f, extent=extent, cmap=cm.jet)
ax[0, 0].set_title('True values')
ax[0, 1].imshow(dfx, extent=extent, cmap=cm.jet)
ax[0, 1].set_title('True x-derivatives')
ax[0, 2].imshow(dfy, extent=extent, cmap=cm.jet)
ax[0, 2].set_title('True y-derivatives')

ax[1, 0].imshow(mean[:, 0].detach().numpy().reshape(n1, n2), extent=extent, cmap=cm.jet)
ax[1, 0].set_title('Predicted values')
ax[1, 1].imshow(mean[:, 1].detach().numpy().reshape(n1, n2), extent=extent, cmap=cm.jet)
ax[1, 1].set_title('Predicted x-derivatives')
ax[1, 2].imshow(mean[:, 2].detach().numpy().reshape(n1, n2), extent=extent, cmap=cm.jet)
ax[1, 2].set_title('Predicted y-derivatives')

None

[Output figure: a 2 x 3 grid of panels comparing the true values, x-derivatives, and y-derivatives of the Franke function (top row) with the model's predicted values and derivatives (bottom row).]
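
The predictive distribution also carries per-task uncertainties. A minimal sketch (assuming the predictions object from the cell above) of extracting pointwise standard deviations:

[ ]:
# Sketch: per-task predictive standard deviations on the 50 x 50 test grid.
# predictions is the MultitaskMultivariateNormal returned by likelihood(model(test_x)).
with torch.no_grad(), gpytorch.settings.fast_computations(log_prob=False, covar_root_decomposition=False):
    std = predictions.variance.sqrt()  # shape (n1*n2, 3): value, x-derivative, y-derivative
value_std = std[:, 0].reshape(n1, n2)  # could be shown with imshow like the means above
print(std.shape)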