Latent Function Inference with Pyro + GPyTorch (Low-Level Interface)

Overview

In this example, we give an overview of the low-level Pyro-GPyTorch integration. The low-level interface makes it possible to write GP models in Pyro style, i.e. by defining your own model and guide functions.

These are the key differences between the high-level and low-level interface:

High level interface

  • Base class is gpytorch.models.PyroGP.

  • GPyTorch automatically defines the model and guide functions for Pyro.

  • Best used when prediction is the primary goal.

Low level interface

  • Base class is gpytorch.models.ApproximateGP.

  • User defines the model and guide functions for Pyro.

  • Best used when inference is the primary goal (a skeleton is sketched below).
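To make the contrast concrete, here is a skeleton of what a low-level model implements (a hypothetical outline only; the full implementation appears in the PVGPRegressionModel cell later in this notebook):

import gpytorch

class LowLevelGPSkeleton(gpytorch.models.ApproximateGP):
    def forward(self, x):
        ...  # return the prior gpytorch.distributions.MultivariateNormal at x

    def guide(self, x, y):
        ...  # pyro.sample the latent function from self.pyro_guide(x)

    def model(self, x, y):
        ...  # pyro.sample the latent function from self.pyro_model(x),
        ...  # transform the samples, and pyro.sample y with obs=y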

[1]:
import math
import torch
import pyro
import tqdm.notebook  # tqdm.notebook.tqdm is used for the training progress bar below
import gpytorch
import numpy as np
from matplotlib import pyplot as plt

%matplotlib inline

This example uses a GP to infer a latent function \(\lambda(x)\), which parameterises the exponential distribution:

\[y \sim \text{Exponential} (\lambda),\]

where:

\[\lambda = \exp(f) \in (0,+\infty)\]

is the link function, which transforms the latent Gaussian process variable:

\[f \sim \mathcal{GP}, \qquad f(x) \in (-\infty,+\infty).\]

In other words, given inputs \(X\) and observations \(Y\) drawn from an exponential distribution with scale \(\lambda = \lambda(X)\), we want to recover \(\lambda(X)\).
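Note that \(\lambda\) plays the role of the exponential distribution's scale (its mean), so the corresponding rate is \(1/\lambda\); this matches np.random.exponential below and the reciprocal in the model. Putting the pieces together, the full generative model is:

\[f \sim \mathcal{GP}(\mu, k), \qquad \lambda(x) = \exp\big(f(x)\big), \qquad y \mid x \sim \text{Exponential}\big(\text{rate} = 1/\lambda(x)\big).\]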

[2]:
# Here we specify a 'true' latent function lambda
scale = lambda x: np.sin(2 * math.pi * x) + 1

# Generate synthetic data: NSamp evenly spaced inputs on [0, 1]
NSamp = 100
X = np.linspace(0, 1, NSamp)

fig, (lambdaf, samples) = plt.subplots(1, 2, figsize=(10, 3))

lambdaf.plot(X, scale(X))
lambdaf.set_xlabel('x')
lambdaf.set_ylabel(r'$\lambda$')
lambdaf.set_title('Latent function')

# Draw one exponential sample per input, using lambda(x) as the scale (mean)
Y = np.random.exponential(scale(X))
samples.scatter(X, Y)
samples.set_xlabel('x')
samples.set_ylabel('y')
samples.set_title('Samples from exp. distrib.')
[Figure: left panel 'Latent function' plots λ(x) against x; right panel 'Samples from exp. distrib.' scatters the y samples against x]
[3]:
train_x = torch.tensor(X).float()
train_y = torch.tensor(Y).float()

Using the low-level Pyro/GPyTorch interface

The low-level interface should look familiar if you’ve written Pyro models/guides before. We’ll use a gpytorch.models.ApproximateGP object to model the GP. To use the low-level interface, this object needs to define three functions:

  • forward(x) - which computes the prior GP mean and covariance at the supplied input locations.

  • guide(x) - which defines the approximate GP posterior.

  • model(x) - which does the following three things:

    • Computes the GP prior at x

    • Converts GP function samples into scale function samples, using the link function defined above.

    • Samples from the observed distribution p(y | f). (This takes the place of the gpytorch Likelihood we would have used in the high-level interface.)

[4]:
class PVGPRegressionModel(gpytorch.models.ApproximateGP):
    def __init__(self, num_inducing=64, name_prefix="mixture_gp"):
        self.name_prefix = name_prefix

        # Define the variational distribution and strategy
        inducing_points = torch.linspace(0, 1, num_inducing)
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points,
            gpytorch.variational.CholeskyVariationalDistribution(num_inducing_points=num_inducing)
        )

        # Standard initialization
        super().__init__(variational_strategy)

        # Mean and covariance modules (note: no gpytorch Likelihood here --
        # the observation model is defined in self.model below)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        mean = self.mean_module(x)
        covar = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean, covar)

    def guide(self, x, y):
        # Get q(f) - variational (guide) distribution of latent function
        function_dist = self.pyro_guide(x)

        # Use a plate here to mark conditional independencies
        with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
            # Sample from latent function distribution
            pyro.sample(self.name_prefix + ".f(x)", function_dist)

    def model(self, x, y):
        pyro.module(self.name_prefix + ".gp", self)

        # Get p(f) - prior distribution of latent function
        function_dist = self.pyro_model(x)

        # Use a plate here to mark conditional independencies
        with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
            # Sample from latent function distribution
            function_samples = pyro.sample(self.name_prefix + ".f(x)", function_dist)

            # Use the link function to convert GP samples into scale samples
            scale_samples = function_samples.exp()

            # Sample from observed distribution
            return pyro.sample(
                self.name_prefix + ".y",
                pyro.distributions.Exponential(scale_samples.reciprocal()),  # rate = 1 / scale
                obs=y
            )
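Before training, a quick sanity check (not part of the original notebook) confirms that forward gives the prior p(f) and that samples pushed through the exp link have the expected shape:

with torch.no_grad():
    prior_f = PVGPRegressionModel().forward(train_x)  # prior p(f) at the 100 training inputs
    prior_scale_samples = prior_f.rsample(torch.Size([4])).exp()  # 4 draws of lambda(x)
    print(prior_scale_samples.shape)  # torch.Size([4, 100])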
[5]:
model = PVGPRegressionModel()

Performing inference with Pyro

Unlike most of the other examples in this library, Pyro-integrated models use Pyro’s inference and optimization classes (rather than the optimizers provided by PyTorch).

If you are unfamiliar with Pyro’s inference tools, we recommend checking out the Pyro SVI tutorial.

[6]:
# this is for running the notebook in our testing framework
import os
smoke_test = ('CI' in os.environ)
num_iter = 2 if smoke_test else 200
num_particles = 1 if smoke_test else 256


def train():
    optimizer = pyro.optim.Adam({"lr": 0.1})
    elbo = pyro.infer.Trace_ELBO(num_particles=num_particles, vectorize_particles=True, retain_graph=True)
    svi = pyro.infer.SVI(model.model, model.guide, optimizer, elbo)

    model.train()
    iterator = tqdm.notebook.tqdm(range(num_iter))
    for i in iterator:
        model.zero_grad()
        loss = svi.step(train_x, train_y)
        iterator.set_postfix(loss=loss, lengthscale=model.covar_module.base_kernel.lengthscale.item())

%time train()

CPU times: user 41.1 s, sys: 2.78 s, total: 43.9 s
Wall time: 6.54 s
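After training, we can inspect the learned kernel hyperparameters, along with anything registered in Pyro's parameter store (a quick check, not in the original notebook):

print('lengthscale:', model.covar_module.base_kernel.lengthscale.item())
print('outputscale:', model.covar_module.outputscale.item())
print('pyro params:', list(pyro.get_param_store().keys()))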

In this example, we are only performing inference over the GP latent function (and its associated hyperparameters). In later examples, we will see that this basic loop also performs inference over any additional latent variables that we define.
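As a hypothetical taste of that (not part of the original notebook; the subclass, the global factor c, and its LogNormal prior are all illustrative assumptions), an additional latent variable is simply another matched pyro.sample site in model and guide:

class ScaledGPModel(PVGPRegressionModel):
    """Illustrative sketch: adds a global multiplicative factor c on the scale."""

    def model(self, x, y):
        pyro.module(self.name_prefix + ".gp", self)

        # Extra latent variable with a LogNormal prior (illustrative choice)
        c = pyro.sample(self.name_prefix + ".c", pyro.distributions.LogNormal(0.0, 1.0))

        function_dist = self.pyro_model(x)
        with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
            function_samples = pyro.sample(self.name_prefix + ".f(x)", function_dist)
            scale_samples = c * function_samples.exp()
            return pyro.sample(
                self.name_prefix + ".y",
                pyro.distributions.Exponential(scale_samples.reciprocal()),
                obs=y
            )

    def guide(self, x, y):
        # Matching variational distribution for c, with learnable parameters
        c_loc = pyro.param(self.name_prefix + ".c_loc", torch.tensor(0.0))
        c_log_sd = pyro.param(self.name_prefix + ".c_log_sd", torch.tensor(0.0))
        pyro.sample(self.name_prefix + ".c", pyro.distributions.LogNormal(c_loc, c_log_sd.exp()))

        function_dist = self.pyro_guide(x)
        with pyro.plate(self.name_prefix + ".data_plate", dim=-1):
            pyro.sample(self.name_prefix + ".f(x)", function_dist)

The same train() loop would then optimize c's variational parameters alongside the GP's.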

Making predictions

For some problems, we simply want to use Pyro to perform inference over latent variables. However, we can also use the model’s (approximate) predictive posterior distribution. Making predictions with a Pyro-integrated GP model is exactly the same as with standard GPyTorch models.

[7]:
# Here's a quick helper function for getting smoothed percentile values from samples
def percentiles_from_samples(samples, percentiles=(0.05, 0.5, 0.95)):
    num_samples = samples.size(0)
    samples = samples.sort(dim=0)[0]

    # Get the samples corresponding to each percentile
    percentile_samples = [samples[int(num_samples * percentile)] for percentile in percentiles]

    # Smooth the percentile curves with a simple moving-average filter
    kernel = torch.full((1, 1, 5), fill_value=0.2)
    percentile_samples = [
        torch.nn.functional.conv1d(percentile_sample.view(1, 1, -1), kernel, padding=2).view(-1)
        for percentile_sample in percentile_samples
    ]

    return percentile_samples
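As a quick illustrative check of this helper (not in the original notebook), 1000 draws of a 50-point function yield three smoothed 50-point curves:

lo, med, hi = percentiles_from_samples(torch.randn(1000, 50))
print(lo.shape, med.shape, hi.shape)  # each is torch.Size([50])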
[8]:
# Define the test set (optionally on the GPU)
denser = 2  # make the test set 2 times denser than the training set
test_x = torch.linspace(0, 1, denser * NSamp).float()  # .cuda()

model.eval()
with torch.no_grad():
    output = model(test_x)

    # Estimate E[exp(f)] by Monte Carlo: draw f_i ~ p(f | D), average exp(f_i).
    # The 5th and 95th percentiles come from the same samples.
    samples = output.rsample(torch.Size([1000])).exp()
    lower, mean, upper = percentiles_from_samples(samples)

    # Draw some simulated y values from the model at the training inputs
    scale_sim = model(train_x).rsample().exp()
    y_sim = pyro.distributions.Exponential(scale_sim.reciprocal()).sample()
[9]:
# Visualize the results
fig, (func, samp) = plt.subplots(1, 2, figsize=(12, 3))

# Left panel: inferred scale function (with 5th-95th percentile band) vs. truth
line, = func.plot(test_x, mean.detach().cpu().numpy(), label='GP prediction')
func.fill_between(
    test_x, lower.detach().cpu().numpy(),
    upper.detach().cpu().numpy(), color=line.get_color(), alpha=0.5
)
func.plot(test_x, scale(test_x.numpy()), label='True latent function')
func.legend()

# Right panel: samples from p(y|D,x) = \int p(y|f) p(f|D,x) df (doubly stochastic)
samp.scatter(train_x, train_y, alpha=0.5, label='True train data')
samp.scatter(train_x, y_sim.cpu().detach().numpy(), alpha=0.5, label='Sample from the model')
samp.legend()
[Figure: left panel shows the GP prediction with its 5th-95th percentile band against the true latent function; right panel compares the training data with y values simulated from the model]

Next steps

This example probably could also have been done (slightly more easily) using the high-level Pyro integration, or with GPyTorch’s native SVGP implementation. The low-level interface really comes in handy when a gpytorch Likelihood is difficult to define. For an example of this, see the next notebook, which uses the low-level interface to infer the intensity function of a Cox process.