Exact GP Regression with Multiple GPUs

Introduction

In this notebook, we’ll demonstrate training exact GPs on large datasets by distributing the kernel matrix across multiple GPUs, for additional parallelism.

NOTE: Kernel partitioning (another memory-saving mechanism introduced in https://arxiv.org/abs/1903.08114) is no longer supported for multiple GPUs. If your kernel matrix is too big to fit on your available GPUs, please use the GPyTorch KeOps integration for kernel partitioning.
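
For reference, switching to the KeOps integration only changes the covariance module of the model. A minimal sketch (assuming the optional pykeops package is installed) might look like:

# Sketch of the KeOps alternative (requires the optional `pykeops` package).
# KeOps evaluates the kernel lazily, so the full kernel matrix is never
# materialized in GPU memory all at once.
covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.keops.RBFKernel())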

We’ll be using the protein dataset, which has about 37,000 training examples. The techniques in this notebook can be applied to much larger datasets, but training time will depend on the computational resources available; in particular, the number of GPUs has a significant effect on how long training takes.

[1]:
import math
import torch
import gpytorch
import sys
from matplotlib import pyplot as plt
sys.path.append('../')
from LBFGS import FullBatchLBFGS

%matplotlib inline
%load_ext autoreload
%autoreload 2
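
Note that FullBatchLBFGS is not part of GPyTorch; it comes from the separate PyTorch-LBFGS project (https://github.com/hjmshi/PyTorch-LBFGS), and the sys.path.append above assumes LBFGS.py lives one directory above this notebook. A quick check like the one below (purely illustrative) will catch a missing file early:

# Purely illustrative: verify that LBFGS.py (from PyTorch-LBFGS) is where we expect it.
import os
assert os.path.isfile('../LBFGS.py'), 'Download LBFGS.py from PyTorch-LBFGS and place it in the parent directory.'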

We will be using the Protein UCI dataset, which contains over 40,000 data points. The next cell will download this dataset from Google Drive and load it.

[2]:
import os
import urllib.request
from scipy.io import loadmat
dataset = 'protein'
if not os.path.isfile(f'../{dataset}.mat'):
    print(f'Downloading \'{dataset}\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1nRb8e7qooozXkNghC5eQS0JeywSXGX2S',
                               f'../{dataset}.mat')

data = torch.Tensor(loadmat(f'../{dataset}.mat')['data'])

Normalization and train/test splits

In the next cell, we split the data 80/20 into train and test sets, and apply basic z-score normalization to the features and labels.

[3]:
import numpy as np

N = data.shape[0]
# split into train/test sets
n_train = int(0.8 * N)
train_x, train_y = data[:n_train, :-1], data[:n_train, -1]
test_x, test_y = data[n_train:, :-1], data[n_train:, -1]

# normalize features
mean = train_x.mean(dim=-2, keepdim=True)
std = train_x.std(dim=-2, keepdim=True) + 1e-6 # prevent dividing by 0
train_x = (train_x - mean) / std
test_x = (test_x - mean) / std

# normalize labels
mean, std = train_y.mean(), train_y.std()
train_y = (train_y - mean) / std
test_y = (test_y - mean) / std

# make contiguous
train_x, train_y = train_x.contiguous(), train_y.contiguous()
test_x, test_y = test_x.contiguous(), test_y.contiguous()

output_device = torch.device('cuda:0')

train_x, train_y = train_x.to(output_device), train_y.to(output_device)
test_x, test_y = test_x.to(output_device), test_y.to(output_device)

How many GPUs do you want to use?

In the next cell, set the n_devices variable to the number of GPUs you’d like to use. By default, we will use all available devices.

[4]:
n_devices = torch.cuda.device_count()
print('Planning to run on {} GPUs.'.format(n_devices))
Planning to run on 2 GPUs.

In the next cell we define our GP model and training code. The only difference from the simple GP tutorials is that the base covariance module is wrapped in a MultiDeviceKernel, which distributes the kernel computation across multiple GPUs behind the scenes.

[5]:
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_devices):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        base_covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

        self.covar_module = gpytorch.kernels.MultiDeviceKernel(
            base_covar_module, device_ids=range(n_devices),
            output_device=output_device
        )

    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

def train(train_x,
          train_y,
          n_devices,
          output_device,
          preconditioner_size,
          n_training_iter,
):
    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(output_device)
    model = ExactGPModel(train_x, train_y, likelihood, n_devices).to(output_device)
    model.train()
    likelihood.train()

    optimizer = FullBatchLBFGS(model.parameters(), lr=0.1)
    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)


    with gpytorch.settings.max_preconditioner_size(preconditioner_size):

        def closure():
            optimizer.zero_grad()
            output = model(train_x)
            loss = -mll(output, train_y)
            return loss

        loss = closure()
        loss.backward()

        for i in range(n_training_iter):
            options = {'closure': closure, 'current_loss': loss, 'max_ls': 10}
            loss, _, _, _, _, _, _, fail = optimizer.step(options)

            print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
                i + 1, n_training_iter, loss.item(),
                model.covar_module.module.base_kernel.lengthscale.item(),
                model.likelihood.noise.item()
            ))

            if fail:
                print('Convergence reached!')
                break

    print(f"Finished training on {train_x.size(0)} data points using {n_devices} GPUs.")
    return model, likelihood

Training

[7]:
model, likelihood = train(train_x, train_y,
                          n_devices=n_devices, output_device=output_device,
                          preconditioner_size=100,
                          n_training_iter=20)
Iter 1/20 - Loss: 0.897   lengthscale: 0.486   noise: 0.253
Iter 2/20 - Loss: 0.892   lengthscale: 0.354   noise: 0.144
Iter 3/20 - Loss: 0.859   lengthscale: 0.305   noise: 0.125
Iter 4/20 - Loss: 0.859   lengthscale: 0.292   noise: 0.116
Iter 5/20 - Loss: 0.862   lengthscale: 0.292   noise: 0.116
Convergence reached!
Finished training on 36584 data points using 2 GPUs.
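
Since training at this scale takes a while, it can be convenient to save the learned hyperparameters so they can be reloaded later. The snippet below is one way to do that with standard PyTorch tools (the file name is just an example):

# Optional: persist the learned hyperparameters (file name is arbitrary).
torch.save(model.state_dict(), 'exact_gp_multi_gpu.pth')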

Computing test time caches

[9]:
# Get into evaluation (predictive posterior) mode
model.eval()
likelihood.eval()

with torch.no_grad(), gpytorch.settings.fast_pred_var():
    # Make predictions on a small number of test points to get the test time caches computed
    latent_pred = model(test_x[:2, :])
    del latent_pred  # We don't care about these predictions, we really just want the caches.

Testing: Computing predictions

[11]:
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    %time latent_pred = model(test_x)

test_rmse = torch.sqrt(torch.mean(torch.pow(latent_pred.mean - test_y, 2)))
print(f"Test RMSE: {test_rmse.item()}")
CPU times: user 1.1 s, sys: 728 ms, total: 1.83 s
Wall time: 1.88 s
Test RMSE: 0.551821768283844
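
The latent_pred above is the posterior over the latent function values. If you also want predictive uncertainty for observed values, one option (a small sketch using the trained likelihood) is to push the latent prediction through the likelihood and read off a confidence region:

# Sketch: predictive distribution over observations (latent posterior plus noise),
# together with a roughly 2-standard-deviation confidence region.
with torch.no_grad(), gpytorch.settings.fast_pred_var():
    observed_pred = likelihood(model(test_x))
    lower, upper = observed_pred.confidence_region()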