GPyTorch regression with derivative information in 2d

Introduction

In this notebook, we show how to train a GP regression model in GPyTorch on a 2-dimensional function, given observations of both the function values and the partial derivatives. We model the Franke function, with the values and derivatives contaminated by independent \(\mathcal{N}(0, 0.05^2)\) distributed noise.

[1]:
import torch
import gpytorch
import math
from matplotlib import cm
from matplotlib import pyplot as plt
import numpy as np

%matplotlib inline
%load_ext autoreload
%autoreload 2

Franke function

The following is a vectorized implementation of the 2-dimensional Franke function and its partial derivatives (https://www.sfu.ca/~ssurjano/franke2d.html).

[2]:
def franke(X, Y):
    term1 = .75*torch.exp(-((9*X - 2).pow(2) + (9*Y - 2).pow(2))/4)
    term2 = .75*torch.exp(-((9*X + 1).pow(2))/49 - (9*Y + 1)/10)
    term3 = .5*torch.exp(-((9*X - 7).pow(2) + (9*Y - 3).pow(2))/4)
    term4 = .2*torch.exp(-(9*X - 4).pow(2) - (9*Y - 7).pow(2))

    f = term1 + term2 + term3 - term4
    dfx = -2*(9*X - 2)*9/4 * term1 - 2*(9*X + 1)*9/49 * term2 + \
          -2*(9*X - 7)*9/4 * term3 + 2*(9*X - 4)*9 * term4
    dfy = -2*(9*Y - 2)*9/4 * term1 - 9/10 * term2 + \
          -2*(9*Y - 3)*9/4 * term3 + 2*(9*Y - 7)*9 * term4

    return f, dfx, dfy
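
As a quick sanity check (not part of the original notebook), the analytic derivatives returned by franke can be compared against torch.autograd on a few random points; since the function acts elementwise, the gradients of the summed values coincide with the pointwise partial derivatives.

[ ]:
Xc = torch.rand(5, requires_grad=True)
Yc = torch.rand(5, requires_grad=True)
fc, dfxc, dfyc = franke(Xc, Yc)
# Gradients of the summed values w.r.t. the inputs equal the elementwise partials
gx, gy = torch.autograd.grad(fc.sum(), (Xc, Yc))
print(torch.allclose(dfxc, gx, atol=1e-4), torch.allclose(dfyc, gy, atol=1e-4))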

Setting up the training data

We train on a grid of 100 points in \([0,1] \times [0,1]\), with 10 uniformly spaced points per dimension.

[3]:
xv, yv = torch.meshgrid([torch.linspace(0, 1, 10), torch.linspace(0, 1, 10)])
train_x = torch.cat((
    xv.contiguous().view(xv.numel(), 1),
    yv.contiguous().view(yv.numel(), 1)),
    dim=1
)

f, dfx, dfy = franke(train_x[:, 0], train_x[:, 1])
train_y = torch.stack([f, dfx, dfy], -1).squeeze(1)

train_y += 0.05 * torch.randn(train_y.size()) # Add noise to both values and gradients
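
As a minimal shape check (not in the original notebook), train_x stacks the 100 two-dimensional grid points and train_y holds one column each for the value and the two partial derivatives.

[ ]:
print(train_x.shape)  # torch.Size([100, 2])
print(train_y.shape)  # torch.Size([100, 3]): value, d/dx, d/dy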

Setting up the model

A GP prior on the function values implies a multi-output GP prior on the function values and the partial derivatives; see Section 9.4 of http://www.gaussianprocess.org/gpml/chapters/RW9.pdf for more details. This allows us to use a MultitaskMultivariateNormal and a MultitaskGaussianLikelihood to train a GP model on both function values and gradients. The RBF kernel extended to model the covariance between values and partial derivatives is implemented in RBFKernelGrad, and the corresponding extension of the constant mean is implemented in ConstantMeanGrad.

[4]:
class GPModelWithDerivatives(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPModelWithDerivatives, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMeanGrad()
        self.base_kernel = gpytorch.kernels.RBFKernelGrad(ard_num_dims=2)
        self.covar_module = gpytorch.kernels.ScaleKernel(self.base_kernel)

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultitaskMultivariateNormal(mean_x, covar_x)

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=3)  # Value + x-derivative + y-derivative
model = GPModelWithDerivatives(train_x, train_y, likelihood)
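
To see the joint structure described above, the following sketch (not part of the original notebook) evaluates the prior at the training inputs: each of the 100 points contributes three outputs, so the prior covariance over the training set is a 300 x 300 matrix.

[ ]:
# The model is still in train mode here, so calling it on the training inputs
# returns the prior MultitaskMultivariateNormal.
with torch.no_grad():
    prior = model(train_x)
    print(prior.mean.shape)                # torch.Size([100, 3])
    print(prior.covariance_matrix.shape)   # torch.Size([300, 300])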

Training the model

Training the model is similar to training a standard GP regression model.

[5]:
# Find optimal model hyperparameters
model.train()
likelihood.train()

# Use the adam optimizer
optimizer = torch.optim.Adam([
    {'params': model.parameters()},  # Includes GaussianLikelihood parameters
], lr=0.05)

# "Loss" for GPs - the marginal log likelihood
mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

n_iter = 100
for i in range(n_iter):
    optimizer.zero_grad()
    output = model(train_x)
    loss = -mll(output, train_y)
    loss.backward()
    print("Iter %d/%d - Loss: %.3f   lengthscales: %.3f, %.3f   noise: %.3f" % (
        i + 1, n_iter, loss.item(),
        model.covar_module.base_kernel.lengthscale.squeeze()[0],
        model.covar_module.base_kernel.lengthscale.squeeze()[1],
        model.likelihood.noise.item()
    ))
    optimizer.step()
Iter 1/100 - Loss: 128.629   lengthscales: 0.693, 0.693   noise: 0.693
Iter 2/100 - Loss: 127.182   lengthscales: 0.668, 0.668   noise: 0.668
Iter 3/100 - Loss: 125.838   lengthscales: 0.644, 0.645   noise: 0.644
Iter 4/100 - Loss: 123.982   lengthscales: 0.624, 0.621   noise: 0.621
Iter 5/100 - Loss: 122.799   lengthscales: 0.604, 0.598   noise: 0.598
Iter 6/100 - Loss: 120.909   lengthscales: 0.583, 0.576   noise: 0.576
Iter 7/100 - Loss: 119.255   lengthscales: 0.562, 0.555   noise: 0.554
Iter 8/100 - Loss: 117.506   lengthscales: 0.542, 0.534   noise: 0.533
Iter 9/100 - Loss: 116.083   lengthscales: 0.522, 0.513   noise: 0.513
Iter 10/100 - Loss: 113.978   lengthscales: 0.502, 0.493   noise: 0.493
Iter 11/100 - Loss: 112.242   lengthscales: 0.482, 0.474   noise: 0.473
Iter 12/100 - Loss: 110.389   lengthscales: 0.463, 0.455   noise: 0.455
Iter 13/100 - Loss: 108.644   lengthscales: 0.444, 0.436   noise: 0.436
Iter 14/100 - Loss: 107.660   lengthscales: 0.426, 0.418   noise: 0.419
Iter 15/100 - Loss: 104.480   lengthscales: 0.408, 0.402   noise: 0.401
Iter 16/100 - Loss: 103.058   lengthscales: 0.391, 0.387   noise: 0.385
Iter 17/100 - Loss: 101.174   lengthscales: 0.374, 0.373   noise: 0.369
Iter 18/100 - Loss: 98.379   lengthscales: 0.358, 0.361   noise: 0.353
Iter 19/100 - Loss: 96.482   lengthscales: 0.343, 0.352   noise: 0.338
Iter 20/100 - Loss: 95.282   lengthscales: 0.327, 0.344   noise: 0.323
Iter 21/100 - Loss: 92.911   lengthscales: 0.313, 0.339   noise: 0.309
Iter 22/100 - Loss: 89.532   lengthscales: 0.300, 0.335   noise: 0.295
Iter 23/100 - Loss: 89.324   lengthscales: 0.288, 0.332   noise: 0.282
Iter 24/100 - Loss: 86.490   lengthscales: 0.279, 0.329   noise: 0.269
Iter 25/100 - Loss: 85.546   lengthscales: 0.272, 0.328   noise: 0.257
Iter 26/100 - Loss: 83.578   lengthscales: 0.268, 0.327   noise: 0.245
Iter 27/100 - Loss: 81.732   lengthscales: 0.265, 0.326   noise: 0.234
Iter 28/100 - Loss: 79.472   lengthscales: 0.265, 0.326   noise: 0.223
Iter 29/100 - Loss: 77.669   lengthscales: 0.267, 0.327   noise: 0.212
Iter 30/100 - Loss: 75.215   lengthscales: 0.269, 0.329   noise: 0.202
Iter 31/100 - Loss: 73.676   lengthscales: 0.272, 0.329   noise: 0.193
Iter 32/100 - Loss: 70.514   lengthscales: 0.276, 0.328   noise: 0.183
Iter 33/100 - Loss: 69.765   lengthscales: 0.280, 0.325   noise: 0.175
Iter 34/100 - Loss: 68.525   lengthscales: 0.284, 0.320   noise: 0.166
Iter 35/100 - Loss: 66.181   lengthscales: 0.287, 0.314   noise: 0.158
Iter 36/100 - Loss: 62.446   lengthscales: 0.288, 0.307   noise: 0.150
Iter 37/100 - Loss: 62.009   lengthscales: 0.287, 0.299   noise: 0.143
Iter 38/100 - Loss: 58.204   lengthscales: 0.284, 0.290   noise: 0.136
Iter 39/100 - Loss: 57.167   lengthscales: 0.280, 0.281   noise: 0.130
Iter 40/100 - Loss: 54.072   lengthscales: 0.274, 0.271   noise: 0.123
Iter 41/100 - Loss: 51.696   lengthscales: 0.268, 0.261   noise: 0.117
Iter 42/100 - Loss: 49.792   lengthscales: 0.261, 0.253   noise: 0.111
Iter 43/100 - Loss: 46.250   lengthscales: 0.255, 0.246   noise: 0.106
Iter 44/100 - Loss: 47.110   lengthscales: 0.250, 0.241   noise: 0.101
Iter 45/100 - Loss: 45.541   lengthscales: 0.248, 0.237   noise: 0.096
Iter 46/100 - Loss: 41.711   lengthscales: 0.246, 0.237   noise: 0.091
Iter 47/100 - Loss: 40.852   lengthscales: 0.245, 0.237   noise: 0.086
Iter 48/100 - Loss: 39.588   lengthscales: 0.244, 0.239   noise: 0.082
Iter 49/100 - Loss: 36.817   lengthscales: 0.244, 0.241   noise: 0.078
Iter 50/100 - Loss: 34.773   lengthscales: 0.244, 0.244   noise: 0.074
Iter 51/100 - Loss: 31.050   lengthscales: 0.243, 0.247   noise: 0.070
Iter 52/100 - Loss: 28.448   lengthscales: 0.242, 0.248   noise: 0.067
Iter 53/100 - Loss: 29.796   lengthscales: 0.241, 0.246   noise: 0.063
Iter 54/100 - Loss: 25.501   lengthscales: 0.239, 0.243   noise: 0.060
Iter 55/100 - Loss: 28.542   lengthscales: 0.237, 0.238   noise: 0.057
Iter 56/100 - Loss: 23.089   lengthscales: 0.236, 0.231   noise: 0.054
Iter 57/100 - Loss: 19.792   lengthscales: 0.235, 0.225   noise: 0.051
Iter 58/100 - Loss: 20.285   lengthscales: 0.235, 0.219   noise: 0.049
Iter 59/100 - Loss: 16.047   lengthscales: 0.234, 0.214   noise: 0.046
Iter 60/100 - Loss: 15.160   lengthscales: 0.234, 0.211   noise: 0.044
Iter 61/100 - Loss: 13.038   lengthscales: 0.232, 0.209   noise: 0.042
Iter 62/100 - Loss: 13.928   lengthscales: 0.230, 0.209   noise: 0.040
Iter 63/100 - Loss: 9.312   lengthscales: 0.227, 0.210   noise: 0.038
Iter 64/100 - Loss: 7.950   lengthscales: 0.223, 0.212   noise: 0.036
Iter 65/100 - Loss: 3.461   lengthscales: 0.220, 0.215   noise: 0.034
Iter 66/100 - Loss: 5.609   lengthscales: 0.217, 0.217   noise: 0.033
Iter 67/100 - Loss: 2.204   lengthscales: 0.214, 0.218   noise: 0.031
Iter 68/100 - Loss: 0.597   lengthscales: 0.212, 0.219   noise: 0.029
Iter 69/100 - Loss: -1.111   lengthscales: 0.211, 0.217   noise: 0.028
Iter 70/100 - Loss: -2.389   lengthscales: 0.209, 0.214   noise: 0.027
Iter 71/100 - Loss: -3.256   lengthscales: 0.208, 0.210   noise: 0.025
Iter 72/100 - Loss: -4.180   lengthscales: 0.209, 0.207   noise: 0.024
Iter 73/100 - Loss: -6.345   lengthscales: 0.209, 0.205   noise: 0.023
Iter 74/100 - Loss: -10.216   lengthscales: 0.210, 0.204   noise: 0.022
Iter 75/100 - Loss: -11.749   lengthscales: 0.209, 0.204   noise: 0.021
Iter 76/100 - Loss: -10.651   lengthscales: 0.208, 0.204   noise: 0.020
Iter 77/100 - Loss: -12.092   lengthscales: 0.207, 0.205   noise: 0.019
Iter 78/100 - Loss: -14.908   lengthscales: 0.204, 0.206   noise: 0.018
Iter 79/100 - Loss: -16.482   lengthscales: 0.202, 0.208   noise: 0.017
Iter 80/100 - Loss: -17.962   lengthscales: 0.199, 0.207   noise: 0.016
Iter 81/100 - Loss: -23.044   lengthscales: 0.198, 0.207   noise: 0.016
Iter 82/100 - Loss: -20.867   lengthscales: 0.196, 0.205   noise: 0.015
Iter 83/100 - Loss: -20.908   lengthscales: 0.195, 0.203   noise: 0.014
Iter 84/100 - Loss: -25.210   lengthscales: 0.193, 0.201   noise: 0.013
Iter 85/100 - Loss: -24.521   lengthscales: 0.193, 0.199   noise: 0.013
Iter 86/100 - Loss: -25.571   lengthscales: 0.193, 0.199   noise: 0.012
Iter 87/100 - Loss: -26.477   lengthscales: 0.194, 0.199   noise: 0.012
Iter 88/100 - Loss: -26.940   lengthscales: 0.195, 0.200   noise: 0.011
Iter 89/100 - Loss: -27.446   lengthscales: 0.196, 0.199   noise: 0.011
Iter 90/100 - Loss: -30.484   lengthscales: 0.196, 0.198   noise: 0.010
Iter 91/100 - Loss: -29.450   lengthscales: 0.194, 0.196   noise: 0.010
Iter 92/100 - Loss: -28.761   lengthscales: 0.192, 0.198   noise: 0.009
Iter 93/100 - Loss: -34.818   lengthscales: 0.189, 0.200   noise: 0.009
Iter 94/100 - Loss: -39.531   lengthscales: 0.186, 0.203   noise: 0.009
Iter 95/100 - Loss: -38.291   lengthscales: 0.184, 0.202   noise: 0.008
Iter 96/100 - Loss: -38.961   lengthscales: 0.182, 0.200   noise: 0.008
Iter 97/100 - Loss: -41.103   lengthscales: 0.180, 0.197   noise: 0.007
Iter 98/100 - Loss: -42.563   lengthscales: 0.179, 0.194   noise: 0.007
Iter 99/100 - Loss: -42.571   lengthscales: 0.179, 0.191   noise: 0.007
Iter 100/100 - Loss: -37.692   lengthscales: 0.179, 0.191   noise: 0.007

Making predictions with the model

Making predictions works the same way as in standard GP regression on function values only, but to get accurate estimates of the predictive variance we disable GPyTorch's fast approximate computations (CG-based log-probabilities and low-rank root decompositions) and fall back to exact linear algebra.

[6]:
# Set into eval mode
model.eval()
likelihood.eval()

# Initialize plots
fig, ax = plt.subplots(2, 3, figsize=(14, 10))

# Test points
n1, n2 = 50, 50
xv, yv = torch.meshgrid([torch.linspace(0, 1, n1), torch.linspace(0, 1, n2)])
f, dfx, dfy = franke(xv, yv)

# Make predictions
with torch.no_grad(), gpytorch.settings.fast_computations(log_prob=False, covar_root_decomposition=False):
    test_x = torch.stack([xv.reshape(n1*n2, 1), yv.reshape(n1*n2, 1)], -1).squeeze(1)
    predictions = likelihood(model(test_x))
    mean = predictions.mean

extent = (xv.min(), xv.max(), yv.max(), yv.min())
ax[0, 0].imshow(f, extent=extent, cmap=cm.jet)
ax[0, 0].set_title('True values')
ax[0, 1].imshow(dfx, extent=extent, cmap=cm.jet)
ax[0, 1].set_title('True x-derivatives')
ax[0, 2].imshow(dfy, extent=extent, cmap=cm.jet)
ax[0, 2].set_title('True y-derivatives')

ax[1, 0].imshow(mean[:, 0].detach().numpy().reshape(n1, n2), extent=extent, cmap=cm.jet)
ax[1, 0].set_title('Predicted values')
ax[1, 1].imshow(mean[:, 1].detach().numpy().reshape(n1, n2), extent=extent, cmap=cm.jet)
ax[1, 1].set_title('Predicted x-derivatives')
ax[1, 2].imshow(mean[:, 2].detach().numpy().reshape(n1, n2), extent=extent, cmap=cm.jet)
ax[1, 2].set_title('Predicted y-derivatives')

None
[Figure: true values, x-derivatives, and y-derivatives (top row) and the corresponding predicted values and derivatives (bottom row) on the 50 x 50 test grid.]
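
As a rough quantitative check (not part of the original notebook), the predictive means can also be compared against the noise-free truth on the same test grid.

[ ]:
with torch.no_grad():
    truth = torch.stack([f.reshape(-1), dfx.reshape(-1), dfy.reshape(-1)], -1)
    mae = (mean - truth).abs().mean(dim=0)   # columns: value, d/dx, d/dy
    print("MAE  value: %.4f   d/dx: %.4f   d/dy: %.4f"
          % (mae[0].item(), mae[1].item(), mae[2].item()))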