Deep Gaussian Processes with Doubly Stochastic VI

In this notebook, we provide a GPyTorch implementation of deep Gaussian processes, where training and inference are performed using the doubly stochastic variational inference method of Salimbeni et al., 2017 (https://arxiv.org/abs/1705.08933), adapted to CG-based inference.

We’ll be training a simple two-layer deep GP on the elevators UCI dataset.

[32]:
%set_env CUDA_VISIBLE_DEVICES=0

import torch
import gpytorch
from torch.nn import Linear
from gpytorch.means import ConstantMean, LinearMean
from gpytorch.kernels import RBFKernel, ScaleKernel
from gpytorch.variational import VariationalStrategy, CholeskyVariationalDistribution
from gpytorch.distributions import MultivariateNormal
from gpytorch.models import ApproximateGP, GP
from gpytorch.mlls import VariationalELBO, AddedLossTerm
from gpytorch.likelihoods import GaussianLikelihood

env: CUDA_VISIBLE_DEVICES=0
[33]:
from gpytorch.models.deep_gps import AbstractDeepGPLayer, AbstractDeepGP, DeepLikelihood

Loading Data

For this example notebook, we’ll be using the elevators UCI dataset used in the paper. Running the next cell downloads a copy of the dataset. We then shuffle the data, split it 80/20 into training and test sets, and normalize the features and targets using statistics computed on the training set.

Note: Running the next cell will attempt to download a ~400 KB dataset file to the current directory.

[34]:
import urllib.request
import os.path
from scipy.io import loadmat
from math import floor
import numpy as np

if not os.path.isfile('elevators.mat'):
    print('Downloading \'elevators\' UCI dataset...')
    urllib.request.urlretrieve('https://drive.google.com/uc?export=download&id=1jhWL3YUHvXIaftia4qeAyDwVxo6j1alk', 'elevators.mat')

data = torch.Tensor(loadmat('elevators.mat')['data'])

# Shuffle the rows *before* splitting into features/targets and train/test,
# so that the train/test split is random
N = data.shape[0]
np.random.seed(0)
data = data[np.random.permutation(np.arange(N)), :]

X = data[:, :-1]
y = data[:, -1]

train_n = int(floor(0.8 * len(X)))

train_x = X[:train_n, :].contiguous().cuda()
train_y = y[:train_n].contiguous().cuda()

test_x = X[train_n:, :].contiguous().cuda()
test_y = y[train_n:].contiguous().cuda()

# Normalize features and targets using training-set statistics
mean = train_x.mean(dim=-2, keepdim=True)
std = train_x.std(dim=-2, keepdim=True) + 1e-6
train_x = (train_x - mean) / std
test_x = (test_x - mean) / std

mean, std = train_y.mean(), train_y.std()
train_y = (train_y - mean) / std
test_y = (test_y - mean) / std
[35]:
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(train_x, train_y)
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)

Defining GP layers

In GPyTorch, defining a GP involves extending one of our abstract GP models and defining a forward method that returns the prior. For deep GPs, things are similar, but there are two abstract GP models that must be extended: one for the hidden layers and one for the deep GP model itself.

In the next cell, we define an example deep GP hidden layer. This looks very similar to every other variational GP you might define. However, there are a few key differences:

  1. Instead of extending ApproximateGP, we extend AbstractDeepGPLayer.
  2. An AbstractDeepGPLayer needs a number of input dimensions, a number of output dimensions, and a number of samples. This is analogous to a linear layer in a standard neural network: input_dims defines how many inputs this hidden layer will expect, and output_dims defines how many hidden GPs to create outputs for (see the one-line sketch below).
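
For example (a hypothetical construction using the ToyDeepGPHiddenLayer class defined in the next cell), a hidden layer mapping 10 input dimensions to 3 hidden GPs is built much like a Linear(10, 3) module:

hidden_layer = ToyDeepGPHiddenLayer(input_dims=10, output_dims=3)  # 3 hidden GPs over 10-dim inputs
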
[37]:
class ToyDeepGPHiddenLayer(AbstractDeepGPLayer):
    def __init__(self, input_dims, output_dims, num_inducing=128, mean_type='constant'):
        if output_dims is None:
            inducing_points = torch.randn(num_inducing, input_dims)
        else:
            inducing_points = torch.randn(output_dims, num_inducing, input_dims)

        variational_distribution = CholeskyVariationalDistribution(
            num_inducing_points=num_inducing,
            batch_shape=torch.Size([output_dims]) if output_dims is not None else torch.Size([])
        )

        variational_strategy = VariationalStrategy(
            self,
            inducing_points,
            variational_distribution,
            learn_inducing_locations=True
        )

        super(ToyDeepGPHiddenLayer, self).__init__(variational_strategy, input_dims, output_dims)

        if mean_type == 'constant':
            self.mean_module = ConstantMean(batch_size=output_dims)
        else:
            self.mean_module = LinearMean(input_dims)
        self.covar_module = ScaleKernel(
            RBFKernel(batch_size=output_dims, ard_num_dims=input_dims),
            batch_size=output_dims, ard_num_dims=None
        )

        # Unused here; kept as an alternative, hand-rolled linear mean (see forward below)
        self.linear_layer = Linear(input_dims, 1)

    def forward(self, x):
        mean_x = self.mean_module(x)  # alternatively: self.linear_layer(x).squeeze(-1)
        covar_x = self.covar_module(x)
        return MultivariateNormal(mean_x, covar_x)

    def __call__(self, x, *other_inputs, **kwargs):
        """
        Overriding __call__ isn't strictly necessary, but it lets us add concatenation based skip connections
        easily. For example, hidden_layer2(hidden_layer1_outputs, inputs) will pass the concatenation of the first
        hidden layer's outputs and the input data to hidden_layer2.
        """
        if len(other_inputs):
            if isinstance(x, gpytorch.distributions.MultitaskMultivariateNormal):
                x = x.rsample()

            processed_inputs = [
                inp.unsqueeze(0).expand(self.num_samples, *inp.shape)
                for inp in other_inputs
            ]

            x = torch.cat([x] + processed_inputs, dim=-1)

        return super().__call__(x, are_samples=bool(len(other_inputs)))

Building the model

Now that we’ve defined a layer class that can serve as both hidden and output layers, we can build our deep GP. To do this, we create a Module whose forward method is simply responsible for passing the data through the various layers.

This setup also makes various network connectivities easy to implement. For example, calling

hidden_rep2 = self.second_hidden_layer(hidden_rep1, inputs)

in forward would cause the second hidden layer to use both the output of the first hidden layer and the input data as inputs, concatenating the two together (a full sketch of this pattern follows below).
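
As a concrete sketch of this pattern (hypothetical; it reuses the ToyDeepGPHiddenLayer class above and is not the model trained below), a three-layer deep GP with a concatenation-based skip connection might look like:

class SkipDeepGP(AbstractDeepGP):
    def __init__(self, train_x_shape):
        input_dims = train_x_shape[-1]
        first_hidden_layer = ToyDeepGPHiddenLayer(
            input_dims=input_dims, output_dims=5, mean_type='linear',
        )
        # The second layer sees the first layer's outputs concatenated with
        # the raw inputs, so its input_dims is the sum of the two widths
        second_hidden_layer = ToyDeepGPHiddenLayer(
            input_dims=first_hidden_layer.output_dims + input_dims,
            output_dims=5, mean_type='linear',
        )
        last_layer = ToyDeepGPHiddenLayer(
            input_dims=second_hidden_layer.output_dims, output_dims=None,
        )
        super().__init__()
        self.first_hidden_layer = first_hidden_layer
        self.second_hidden_layer = second_hidden_layer
        self.last_layer = last_layer
        self.likelihood = DeepLikelihood(GaussianLikelihood())

    def forward(self, inputs):
        hidden_rep1 = self.first_hidden_layer(inputs)
        # Skip connection: pass the raw inputs alongside the hidden representation
        hidden_rep2 = self.second_hidden_layer(hidden_rep1, inputs)
        return self.last_layer(hidden_rep2)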

[38]:
class DeepGP(AbstractDeepGP):
    def __init__(self, train_x_shape):
        hidden_layer = ToyDeepGPHiddenLayer(
            input_dims=train_x_shape[-1],
            output_dims=10,
            mean_type='linear',
        )

        last_layer = ToyDeepGPHiddenLayer(
            input_dims=hidden_layer.output_dims,
            output_dims=None,
            mean_type='constant',
        )

        super().__init__()

        self.hidden_layer = hidden_layer
        self.last_layer = last_layer
        self.likelihood = DeepLikelihood(GaussianLikelihood())

    def forward(self, inputs):
        hidden_rep1 = self.hidden_layer(inputs)
        output = self.last_layer(hidden_rep1)
        return output

    def predict(self, x):
        # Use exact (Cholesky-based) computations and no gradients at prediction time
        with gpytorch.settings.fast_computations(log_prob=False, solves=False), torch.no_grad():
            preds = self.likelihood.base_likelihood(self(x))
        predictive_means = preds.mean
        predictive_variances = preds.variance

        return predictive_means, predictive_variances
[39]:
model = DeepGP(train_x.shape).cuda()

Likelihood

Because deep GPs use some amount of internal sampling (even in the stochastic variational setting), we need to handle the likelihood in a slightly different way. In the future, we anticipate DeepLikelihood being a general wrapper around an arbitrary likelihood once likelihoods become a little more general purpose, but for now we simply wrap a standard GaussianLikelihood in a DeepLikelihood for regression.
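
Concretely, this is the single line already used in the DeepGP model above; the wrapped likelihood stays accessible at prediction time:

likelihood = DeepLikelihood(GaussianLikelihood())  # wrap the base likelihood for deep GP training
base_likelihood = likelihood.base_likelihood       # the plain GaussianLikelihood, used in predict()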

Training the model

The training loop for a deep GP looks similar to that of a standard GP model trained with stochastic variational inference, but there are a few differences:

  1. The output of a deep GP is actually num_samples Gaussians per data point rather than a single Gaussian, so the labels must be matched against every sample when computing the ELBO. The DeepLikelihood wrapper handles this bookkeeping, which lets us pass y_batch to the ELBO unchanged.
  2. Because deep GPs draw samples internally, the loss is stochastic even for a fixed minibatch. We also run training under gpytorch.settings.fast_computations(log_prob=False, solves=False), so the variational terms are computed with Cholesky factorizations rather than the fast CG-based routines.
[40]:
import time

num_epochs = 60

optimizer = torch.optim.Adam([
    {'params': model.parameters()},
], lr=0.01)
# The third argument is the total number of training points, used to
# rescale the minibatch ELBO
mll = VariationalELBO(model.likelihood, model, train_x.shape[-2])

# Compute the variational terms with Cholesky factorizations rather than CG
with gpytorch.settings.fast_computations(log_prob=False, solves=False):
    for i in range(num_epochs):
        for minibatch_i, (x_batch, y_batch) in enumerate(train_loader):
            start_time = time.time()
            optimizer.zero_grad()

            output = model(x_batch)
            loss = -mll(output, y_batch)
            print('Epoch %d [%d/%d] - Loss: %.3f - - Time: %.3f' % (
                i + 1, minibatch_i, len(train_loader), loss.item(), time.time() - start_time
            ))

            loss.backward()
            optimizer.step()
Epoch 1 [0/13] - Loss: 1.951 - - Time: 0.028
Epoch 1 [1/13] - Loss: 1.972 - - Time: 0.023
Epoch 1 [2/13] - Loss: 1.966 - - Time: 0.026
Epoch 1 [3/13] - Loss: 1.903 - - Time: 0.023
Epoch 1 [4/13] - Loss: 2.000 - - Time: 0.022
Epoch 1 [5/13] - Loss: 1.915 - - Time: 0.022
Epoch 1 [6/13] - Loss: 1.879 - - Time: 0.022
Epoch 1 [7/13] - Loss: 1.878 - - Time: 0.022
Epoch 1 [8/13] - Loss: 1.897 - - Time: 0.022
Epoch 1 [9/13] - Loss: 1.824 - - Time: 0.022
Epoch 1 [10/13] - Loss: 1.853 - - Time: 0.023
Epoch 1 [11/13] - Loss: 1.850 - - Time: 0.022
Epoch 1 [12/13] - Loss: 1.919 - - Time: 0.021
Epoch 2 [0/13] - Loss: 1.859 - - Time: 0.022
Epoch 2 [1/13] - Loss: 1.827 - - Time: 0.022
Epoch 2 [2/13] - Loss: 1.844 - - Time: 0.022
Epoch 2 [3/13] - Loss: 1.848 - - Time: 0.023
Epoch 2 [4/13] - Loss: 1.786 - - Time: 0.022
Epoch 2 [5/13] - Loss: 1.747 - - Time: 0.022
Epoch 2 [6/13] - Loss: 1.849 - - Time: 0.023
Epoch 2 [7/13] - Loss: 1.775 - - Time: 0.023
Epoch 2 [8/13] - Loss: 1.838 - - Time: 0.022
Epoch 2 [9/13] - Loss: 1.887 - - Time: 0.024
Epoch 2 [10/13] - Loss: 1.772 - - Time: 0.023
Epoch 2 [11/13] - Loss: 1.824 - - Time: 0.022
Epoch 2 [12/13] - Loss: 1.758 - - Time: 0.023
Epoch 3 [0/13] - Loss: 1.756 - - Time: 0.023
Epoch 3 [1/13] - Loss: 1.788 - - Time: 0.022
Epoch 3 [2/13] - Loss: 1.764 - - Time: 0.023
Epoch 3 [3/13] - Loss: 1.775 - - Time: 0.022
Epoch 3 [4/13] - Loss: 1.743 - - Time: 0.023
Epoch 3 [5/13] - Loss: 1.719 - - Time: 0.024
Epoch 3 [6/13] - Loss: 1.778 - - Time: 0.022
Epoch 3 [7/13] - Loss: 1.761 - - Time: 0.022
Epoch 3 [8/13] - Loss: 1.741 - - Time: 0.023
Epoch 3 [9/13] - Loss: 1.733 - - Time: 0.022
Epoch 3 [10/13] - Loss: 1.695 - - Time: 0.022
Epoch 3 [11/13] - Loss: 1.723 - - Time: 0.023
Epoch 3 [12/13] - Loss: 1.719 - - Time: 0.022
Epoch 4 [0/13] - Loss: 1.694 - - Time: 0.023
Epoch 4 [1/13] - Loss: 1.705 - - Time: 0.023
Epoch 4 [2/13] - Loss: 1.698 - - Time: 0.027
Epoch 4 [3/13] - Loss: 1.729 - - Time: 0.023
Epoch 4 [4/13] - Loss: 1.681 - - Time: 0.023
Epoch 4 [5/13] - Loss: 1.678 - - Time: 0.023
Epoch 4 [6/13] - Loss: 1.707 - - Time: 0.022
Epoch 4 [7/13] - Loss: 1.693 - - Time: 0.023
Epoch 4 [8/13] - Loss: 1.639 - - Time: 0.022
Epoch 4 [9/13] - Loss: 1.723 - - Time: 0.022
Epoch 4 [10/13] - Loss: 1.654 - - Time: 0.024
Epoch 4 [11/13] - Loss: 1.656 - - Time: 0.023
Epoch 4 [12/13] - Loss: 1.681 - - Time: 0.022
Epoch 5 [0/13] - Loss: 1.648 - - Time: 0.022
Epoch 5 [1/13] - Loss: 1.601 - - Time: 0.021
Epoch 5 [2/13] - Loss: 1.678 - - Time: 0.022
Epoch 5 [3/13] - Loss: 1.644 - - Time: 0.023
Epoch 5 [4/13] - Loss: 1.666 - - Time: 0.021
Epoch 5 [5/13] - Loss: 1.655 - - Time: 0.022
Epoch 5 [6/13] - Loss: 1.606 - - Time: 0.022
Epoch 5 [7/13] - Loss: 1.598 - - Time: 0.023
Epoch 5 [8/13] - Loss: 1.611 - - Time: 0.023
Epoch 5 [9/13] - Loss: 1.631 - - Time: 0.023
Epoch 5 [10/13] - Loss: 1.594 - - Time: 0.024
Epoch 5 [11/13] - Loss: 1.637 - - Time: 0.022
Epoch 5 [12/13] - Loss: 1.611 - - Time: 0.022
Epoch 6 [0/13] - Loss: 1.602 - - Time: 0.024
Epoch 6 [1/13] - Loss: 1.575 - - Time: 0.023
Epoch 6 [2/13] - Loss: 1.607 - - Time: 0.022
Epoch 6 [3/13] - Loss: 1.539 - - Time: 0.024
Epoch 6 [4/13] - Loss: 1.560 - - Time: 0.023
Epoch 6 [5/13] - Loss: 1.543 - - Time: 0.022
Epoch 6 [6/13] - Loss: 1.514 - - Time: 0.023
Epoch 6 [7/13] - Loss: 1.504 - - Time: 0.023
Epoch 6 [8/13] - Loss: 1.582 - - Time: 0.022
Epoch 6 [9/13] - Loss: 1.604 - - Time: 0.023
Epoch 6 [10/13] - Loss: 1.526 - - Time: 0.022
Epoch 6 [11/13] - Loss: 1.553 - - Time: 0.022
Epoch 6 [12/13] - Loss: 1.526 - - Time: 0.023
Epoch 7 [0/13] - Loss: 1.522 - - Time: 0.022
Epoch 7 [1/13] - Loss: 1.492 - - Time: 0.028
Epoch 7 [2/13] - Loss: 1.523 - - Time: 0.023
Epoch 7 [3/13] - Loss: 1.514 - - Time: 0.022
Epoch 7 [4/13] - Loss: 1.474 - - Time: 0.023
Epoch 7 [5/13] - Loss: 1.521 - - Time: 0.023
Epoch 7 [6/13] - Loss: 1.466 - - Time: 0.022
Epoch 7 [7/13] - Loss: 1.444 - - Time: 0.022
Epoch 7 [8/13] - Loss: 1.434 - - Time: 0.023
Epoch 7 [9/13] - Loss: 1.391 - - Time: 0.022
Epoch 7 [10/13] - Loss: 1.433 - - Time: 0.022
Epoch 7 [11/13] - Loss: 1.447 - - Time: 0.023
Epoch 7 [12/13] - Loss: 1.479 - - Time: 0.022
Epoch 8 [0/13] - Loss: 1.404 - - Time: 0.022
Epoch 8 [1/13] - Loss: 1.401 - - Time: 0.023
Epoch 8 [2/13] - Loss: 1.419 - - Time: 0.022
Epoch 8 [3/13] - Loss: 1.379 - - Time: 0.022
Epoch 8 [4/13] - Loss: 1.368 - - Time: 0.023
Epoch 8 [5/13] - Loss: 1.369 - - Time: 0.023
Epoch 8 [6/13] - Loss: 1.365 - - Time: 0.022
Epoch 8 [7/13] - Loss: 1.387 - - Time: 0.022
Epoch 8 [8/13] - Loss: 1.328 - - Time: 0.022
Epoch 8 [9/13] - Loss: 1.341 - - Time: 0.022
Epoch 8 [10/13] - Loss: 1.332 - - Time: 0.023
Epoch 8 [11/13] - Loss: 1.402 - - Time: 0.022
Epoch 8 [12/13] - Loss: 1.343 - - Time: 0.022
Epoch 9 [0/13] - Loss: 1.312 - - Time: 0.023
Epoch 9 [1/13] - Loss: 1.309 - - Time: 0.022
Epoch 9 [2/13] - Loss: 1.270 - - Time: 0.022
Epoch 9 [3/13] - Loss: 1.290 - - Time: 0.026
Epoch 9 [4/13] - Loss: 1.245 - - Time: 0.023
Epoch 9 [5/13] - Loss: 1.266 - - Time: 0.022
Epoch 9 [6/13] - Loss: 1.227 - - Time: 0.023
Epoch 9 [7/13] - Loss: 1.265 - - Time: 0.022
Epoch 9 [8/13] - Loss: 1.252 - - Time: 0.022
Epoch 9 [9/13] - Loss: 1.235 - - Time: 0.023
Epoch 9 [10/13] - Loss: 1.234 - - Time: 0.022
Epoch 9 [11/13] - Loss: 1.210 - - Time: 0.022
Epoch 9 [12/13] - Loss: 1.203 - - Time: 0.023
Epoch 10 [0/13] - Loss: 1.170 - - Time: 0.022
Epoch 10 [1/13] - Loss: 1.197 - - Time: 0.023
Epoch 10 [2/13] - Loss: 1.179 - - Time: 0.023
Epoch 10 [3/13] - Loss: 1.188 - - Time: 0.022
Epoch 10 [4/13] - Loss: 1.155 - - Time: 0.023
Epoch 10 [5/13] - Loss: 1.172 - - Time: 0.023
Epoch 10 [6/13] - Loss: 1.135 - - Time: 0.022
Epoch 10 [7/13] - Loss: 1.173 - - Time: 0.022
Epoch 10 [8/13] - Loss: 1.127 - - Time: 0.024
Epoch 10 [9/13] - Loss: 1.127 - - Time: 0.023
Epoch 10 [10/13] - Loss: 1.126 - - Time: 0.022
Epoch 10 [11/13] - Loss: 1.116 - - Time: 0.023
Epoch 10 [12/13] - Loss: 1.115 - - Time: 0.022
Epoch 11 [0/13] - Loss: 1.112 - - Time: 0.022
Epoch 11 [1/13] - Loss: 1.124 - - Time: 0.023
Epoch 11 [2/13] - Loss: 1.093 - - Time: 0.022
Epoch 11 [3/13] - Loss: 1.106 - - Time: 0.022
Epoch 11 [4/13] - Loss: 1.068 - - Time: 0.023
Epoch 11 [5/13] - Loss: 1.062 - - Time: 0.022
Epoch 11 [6/13] - Loss: 1.080 - - Time: 0.022
Epoch 11 [7/13] - Loss: 1.081 - - Time: 0.023
Epoch 11 [8/13] - Loss: 1.074 - - Time: 0.023
Epoch 11 [9/13] - Loss: 1.055 - - Time: 0.022
Epoch 11 [10/13] - Loss: 1.071 - - Time: 0.022
Epoch 11 [11/13] - Loss: 1.065 - - Time: 0.022
Epoch 11 [12/13] - Loss: 1.068 - - Time: 0.022
Epoch 12 [0/13] - Loss: 1.035 - - Time: 0.022
Epoch 12 [1/13] - Loss: 1.077 - - Time: 0.022
Epoch 12 [2/13] - Loss: 1.042 - - Time: 0.022
Epoch 12 [3/13] - Loss: 1.023 - - Time: 0.022
Epoch 12 [4/13] - Loss: 1.033 - - Time: 0.023
Epoch 12 [5/13] - Loss: 1.024 - - Time: 0.022
Epoch 12 [6/13] - Loss: 1.018 - - Time: 0.022
Epoch 12 [7/13] - Loss: 1.008 - - Time: 0.023
Epoch 12 [8/13] - Loss: 1.013 - - Time: 0.022
Epoch 12 [9/13] - Loss: 1.011 - - Time: 0.022
Epoch 12 [10/13] - Loss: 1.021 - - Time: 0.023
Epoch 12 [11/13] - Loss: 1.035 - - Time: 0.022
Epoch 12 [12/13] - Loss: 1.015 - - Time: 0.022
Epoch 13 [0/13] - Loss: 1.006 - - Time: 0.023
Epoch 13 [1/13] - Loss: 0.999 - - Time: 0.022
Epoch 13 [2/13] - Loss: 0.999 - - Time: 0.022
Epoch 13 [3/13] - Loss: 1.001 - - Time: 0.023
Epoch 13 [4/13] - Loss: 0.974 - - Time: 0.023
Epoch 13 [5/13] - Loss: 0.978 - - Time: 0.023
Epoch 13 [6/13] - Loss: 0.969 - - Time: 0.023
Epoch 13 [7/13] - Loss: 0.959 - - Time: 0.022
Epoch 13 [8/13] - Loss: 1.000 - - Time: 0.022
Epoch 13 [9/13] - Loss: 0.967 - - Time: 0.023
Epoch 13 [10/13] - Loss: 0.963 - - Time: 0.035
Epoch 13 [11/13] - Loss: 0.953 - - Time: 0.023
Epoch 13 [12/13] - Loss: 0.962 - - Time: 0.023
Epoch 14 [0/13] - Loss: 0.945 - - Time: 0.023
Epoch 14 [1/13] - Loss: 0.944 - - Time: 0.021
Epoch 14 [2/13] - Loss: 0.963 - - Time: 0.022
Epoch 14 [3/13] - Loss: 0.938 - - Time: 0.022
Epoch 14 [4/13] - Loss: 0.941 - - Time: 0.022
Epoch 14 [5/13] - Loss: 0.947 - - Time: 0.022
Epoch 14 [6/13] - Loss: 0.927 - - Time: 0.023
Epoch 14 [7/13] - Loss: 0.920 - - Time: 0.023
Epoch 14 [8/13] - Loss: 0.932 - - Time: 0.023
Epoch 14 [9/13] - Loss: 0.915 - - Time: 0.023
Epoch 14 [10/13] - Loss: 0.911 - - Time: 0.022
Epoch 14 [11/13] - Loss: 0.904 - - Time: 0.022
Epoch 14 [12/13] - Loss: 0.898 - - Time: 0.024
Epoch 15 [0/13] - Loss: 0.897 - - Time: 0.022
Epoch 15 [1/13] - Loss: 0.891 - - Time: 0.023
Epoch 15 [2/13] - Loss: 0.906 - - Time: 0.022
Epoch 15 [3/13] - Loss: 0.886 - - Time: 0.022
Epoch 15 [4/13] - Loss: 0.899 - - Time: 0.024
Epoch 15 [5/13] - Loss: 0.881 - - Time: 0.023
Epoch 15 [6/13] - Loss: 0.884 - - Time: 0.023
Epoch 15 [7/13] - Loss: 0.874 - - Time: 0.023
Epoch 15 [8/13] - Loss: 0.888 - - Time: 0.022
Epoch 15 [9/13] - Loss: 0.863 - - Time: 0.022
Epoch 15 [10/13] - Loss: 0.867 - - Time: 0.023
Epoch 15 [11/13] - Loss: 0.852 - - Time: 0.022
Epoch 15 [12/13] - Loss: 0.861 - - Time: 0.022
Epoch 16 [0/13] - Loss: 0.863 - - Time: 0.023
Epoch 16 [1/13] - Loss: 0.843 - - Time: 0.023
Epoch 16 [2/13] - Loss: 0.852 - - Time: 0.023
Epoch 16 [3/13] - Loss: 0.847 - - Time: 0.022
Epoch 16 [4/13] - Loss: 0.841 - - Time: 0.021
Epoch 16 [5/13] - Loss: 0.836 - - Time: 0.022
Epoch 16 [6/13] - Loss: 0.849 - - Time: 0.022
Epoch 16 [7/13] - Loss: 0.834 - - Time: 0.022
Epoch 16 [8/13] - Loss: 0.827 - - Time: 0.022
Epoch 16 [9/13] - Loss: 0.819 - - Time: 0.021
Epoch 16 [10/13] - Loss: 0.815 - - Time: 0.021
Epoch 16 [11/13] - Loss: 0.832 - - Time: 0.022
Epoch 16 [12/13] - Loss: 0.812 - - Time: 0.023
Epoch 17 [0/13] - Loss: 0.823 - - Time: 0.023
Epoch 17 [1/13] - Loss: 0.817 - - Time: 0.023
Epoch 17 [2/13] - Loss: 0.803 - - Time: 0.023
Epoch 17 [3/13] - Loss: 0.798 - - Time: 0.022
Epoch 17 [4/13] - Loss: 0.803 - - Time: 0.023
Epoch 17 [5/13] - Loss: 0.792 - - Time: 0.023
Epoch 17 [6/13] - Loss: 0.785 - - Time: 0.023
Epoch 17 [7/13] - Loss: 0.778 - - Time: 0.023
Epoch 17 [8/13] - Loss: 0.784 - - Time: 0.023
Epoch 17 [9/13] - Loss: 0.766 - - Time: 0.023
Epoch 17 [10/13] - Loss: 0.772 - - Time: 0.023
Epoch 17 [11/13] - Loss: 0.781 - - Time: 0.022
Epoch 17 [12/13] - Loss: 0.784 - - Time: 0.023
Epoch 18 [0/13] - Loss: 0.774 - - Time: 0.023
Epoch 18 [1/13] - Loss: 0.774 - - Time: 0.023
Epoch 18 [2/13] - Loss: 0.768 - - Time: 0.022
Epoch 18 [3/13] - Loss: 0.757 - - Time: 0.022
Epoch 18 [4/13] - Loss: 0.748 - - Time: 0.022
Epoch 18 [5/13] - Loss: 0.744 - - Time: 0.023
Epoch 18 [6/13] - Loss: 0.767 - - Time: 0.024
Epoch 18 [7/13] - Loss: 0.750 - - Time: 0.023
Epoch 18 [8/13] - Loss: 0.733 - - Time: 0.023
Epoch 18 [9/13] - Loss: 0.742 - - Time: 0.024
Epoch 18 [10/13] - Loss: 0.728 - - Time: 0.023
Epoch 18 [11/13] - Loss: 0.728 - - Time: 0.023
Epoch 18 [12/13] - Loss: 0.728 - - Time: 0.024
Epoch 19 [0/13] - Loss: 0.733 - - Time: 0.023
Epoch 19 [1/13] - Loss: 0.709 - - Time: 0.023
Epoch 19 [2/13] - Loss: 0.727 - - Time: 0.023
Epoch 19 [3/13] - Loss: 0.703 - - Time: 0.023
Epoch 19 [4/13] - Loss: 0.717 - - Time: 0.023
Epoch 19 [5/13] - Loss: 0.731 - - Time: 0.024
Epoch 19 [6/13] - Loss: 0.704 - - Time: 0.024
Epoch 19 [7/13] - Loss: 0.694 - - Time: 0.023
Epoch 19 [8/13] - Loss: 0.700 - - Time: 0.024
Epoch 19 [9/13] - Loss: 0.698 - - Time: 0.022
Epoch 19 [10/13] - Loss: 0.693 - - Time: 0.023
Epoch 19 [11/13] - Loss: 0.684 - - Time: 0.023
Epoch 19 [12/13] - Loss: 0.686 - - Time: 0.023
Epoch 20 [0/13] - Loss: 0.681 - - Time: 0.023
Epoch 20 [1/13] - Loss: 0.668 - - Time: 0.023
Epoch 20 [2/13] - Loss: 0.696 - - Time: 0.023
Epoch 20 [3/13] - Loss: 0.694 - - Time: 0.023
Epoch 20 [4/13] - Loss: 0.663 - - Time: 0.023
Epoch 20 [5/13] - Loss: 0.656 - - Time: 0.023
Epoch 20 [6/13] - Loss: 0.675 - - Time: 0.022
Epoch 20 [7/13] - Loss: 0.687 - - Time: 0.023
Epoch 20 [8/13] - Loss: 0.653 - - Time: 0.023
Epoch 20 [9/13] - Loss: 0.675 - - Time: 0.023
Epoch 20 [10/13] - Loss: 0.681 - - Time: 0.023
Epoch 20 [11/13] - Loss: 0.649 - - Time: 0.023
Epoch 20 [12/13] - Loss: 0.656 - - Time: 0.022
Epoch 21 [0/13] - Loss: 0.639 - - Time: 0.023
Epoch 21 [1/13] - Loss: 0.670 - - Time: 0.022
Epoch 21 [2/13] - Loss: 0.647 - - Time: 0.022
Epoch 21 [3/13] - Loss: 0.636 - - Time: 0.023
Epoch 21 [4/13] - Loss: 0.636 - - Time: 0.023
Epoch 21 [5/13] - Loss: 0.613 - - Time: 0.023
Epoch 21 [6/13] - Loss: 0.646 - - Time: 0.022
Epoch 21 [7/13] - Loss: 0.671 - - Time: 0.022
Epoch 21 [8/13] - Loss: 0.623 - - Time: 0.022
Epoch 21 [9/13] - Loss: 0.622 - - Time: 0.023
Epoch 21 [10/13] - Loss: 0.634 - - Time: 0.022
Epoch 21 [11/13] - Loss: 0.627 - - Time: 0.022
Epoch 21 [12/13] - Loss: 0.581 - - Time: 0.023
Epoch 22 [0/13] - Loss: 0.601 - - Time: 0.023
Epoch 22 [1/13] - Loss: 0.597 - - Time: 0.023
Epoch 22 [2/13] - Loss: 0.608 - - Time: 0.023
Epoch 22 [3/13] - Loss: 0.632 - - Time: 0.022
Epoch 22 [4/13] - Loss: 0.621 - - Time: 0.022
Epoch 22 [5/13] - Loss: 0.598 - - Time: 0.023
Epoch 22 [6/13] - Loss: 0.635 - - Time: 0.022
Epoch 22 [7/13] - Loss: 0.619 - - Time: 0.022
Epoch 22 [8/13] - Loss: 0.597 - - Time: 0.023
Epoch 22 [9/13] - Loss: 0.573 - - Time: 0.022
Epoch 22 [10/13] - Loss: 0.588 - - Time: 0.022
Epoch 22 [11/13] - Loss: 0.578 - - Time: 0.023
Epoch 22 [12/13] - Loss: 0.575 - - Time: 0.022
Epoch 23 [0/13] - Loss: 0.590 - - Time: 0.023
Epoch 23 [1/13] - Loss: 0.595 - - Time: 0.023
Epoch 23 [2/13] - Loss: 0.561 - - Time: 0.022
Epoch 23 [3/13] - Loss: 0.598 - - Time: 0.022
Epoch 23 [4/13] - Loss: 0.572 - - Time: 0.022
Epoch 23 [5/13] - Loss: 0.566 - - Time: 0.023
Epoch 23 [6/13] - Loss: 0.598 - - Time: 0.022
Epoch 23 [7/13] - Loss: 0.579 - - Time: 0.023
Epoch 23 [8/13] - Loss: 0.562 - - Time: 0.022
Epoch 23 [9/13] - Loss: 0.593 - - Time: 0.022
Epoch 23 [10/13] - Loss: 0.557 - - Time: 0.023
Epoch 23 [11/13] - Loss: 0.556 - - Time: 0.023
Epoch 23 [12/13] - Loss: 0.528 - - Time: 0.023
Epoch 24 [0/13] - Loss: 0.550 - - Time: 0.023
Epoch 24 [1/13] - Loss: 0.553 - - Time: 0.022
Epoch 24 [2/13] - Loss: 0.547 - - Time: 0.022
Epoch 24 [3/13] - Loss: 0.538 - - Time: 0.023
Epoch 24 [4/13] - Loss: 0.580 - - Time: 0.022
Epoch 24 [5/13] - Loss: 0.564 - - Time: 0.022
Epoch 24 [6/13] - Loss: 0.546 - - Time: 0.023
Epoch 24 [7/13] - Loss: 0.545 - - Time: 0.022
Epoch 24 [8/13] - Loss: 0.536 - - Time: 0.023
Epoch 24 [9/13] - Loss: 0.542 - - Time: 0.023
Epoch 24 [10/13] - Loss: 0.513 - - Time: 0.022
Epoch 24 [11/13] - Loss: 0.559 - - Time: 0.023
Epoch 24 [12/13] - Loss: 0.548 - - Time: 0.027
Epoch 25 [0/13] - Loss: 0.526 - - Time: 0.022
Epoch 25 [1/13] - Loss: 0.551 - - Time: 0.023
Epoch 25 [2/13] - Loss: 0.501 - - Time: 0.023
Epoch 25 [3/13] - Loss: 0.500 - - Time: 0.022
Epoch 25 [4/13] - Loss: 0.540 - - Time: 0.023
Epoch 25 [5/13] - Loss: 0.535 - - Time: 0.023
Epoch 25 [6/13] - Loss: 0.535 - - Time: 0.022
Epoch 25 [7/13] - Loss: 0.526 - - Time: 0.023
Epoch 25 [8/13] - Loss: 0.503 - - Time: 0.023
Epoch 25 [9/13] - Loss: 0.508 - - Time: 0.022
Epoch 25 [10/13] - Loss: 0.559 - - Time: 0.023
Epoch 25 [11/13] - Loss: 0.546 - - Time: 0.022
Epoch 25 [12/13] - Loss: 0.536 - - Time: 0.022
Epoch 26 [0/13] - Loss: 0.546 - - Time: 0.023
Epoch 26 [1/13] - Loss: 0.505 - - Time: 0.023
Epoch 26 [2/13] - Loss: 0.474 - - Time: 0.022
Epoch 26 [3/13] - Loss: 0.540 - - Time: 0.023
Epoch 26 [4/13] - Loss: 0.511 - - Time: 0.022
Epoch 26 [5/13] - Loss: 0.513 - - Time: 0.022
Epoch 26 [6/13] - Loss: 0.484 - - Time: 0.023
Epoch 26 [7/13] - Loss: 0.487 - - Time: 0.023
Epoch 26 [8/13] - Loss: 0.512 - - Time: 0.022
Epoch 26 [9/13] - Loss: 0.474 - - Time: 0.023
Epoch 26 [10/13] - Loss: 0.547 - - Time: 0.022
Epoch 26 [11/13] - Loss: 0.512 - - Time: 0.022
Epoch 26 [12/13] - Loss: 0.510 - - Time: 0.023
Epoch 27 [0/13] - Loss: 0.495 - - Time: 0.022
Epoch 27 [1/13] - Loss: 0.501 - - Time: 0.022
Epoch 27 [2/13] - Loss: 0.466 - - Time: 0.023
Epoch 27 [3/13] - Loss: 0.486 - - Time: 0.022
Epoch 27 [4/13] - Loss: 0.496 - - Time: 0.022
Epoch 27 [5/13] - Loss: 0.552 - - Time: 0.023
Epoch 27 [6/13] - Loss: 0.519 - - Time: 0.023
Epoch 27 [7/13] - Loss: 0.511 - - Time: 0.023
Epoch 27 [8/13] - Loss: 0.492 - - Time: 0.023
Epoch 27 [9/13] - Loss: 0.467 - - Time: 0.022
Epoch 27 [10/13] - Loss: 0.482 - - Time: 0.022
Epoch 27 [11/13] - Loss: 0.477 - - Time: 0.023
Epoch 27 [12/13] - Loss: 0.500 - - Time: 0.022
Epoch 28 [0/13] - Loss: 0.464 - - Time: 0.022
Epoch 28 [1/13] - Loss: 0.504 - - Time: 0.023
Epoch 28 [2/13] - Loss: 0.504 - - Time: 0.022
Epoch 28 [3/13] - Loss: 0.459 - - Time: 0.022
Epoch 28 [4/13] - Loss: 0.469 - - Time: 0.023
Epoch 28 [5/13] - Loss: 0.508 - - Time: 0.023
Epoch 28 [6/13] - Loss: 0.500 - - Time: 0.022
Epoch 28 [7/13] - Loss: 0.486 - - Time: 0.023
Epoch 28 [8/13] - Loss: 0.477 - - Time: 0.022
Epoch 28 [9/13] - Loss: 0.472 - - Time: 0.022
Epoch 28 [10/13] - Loss: 0.474 - - Time: 0.023
Epoch 28 [11/13] - Loss: 0.473 - - Time: 0.022
Epoch 28 [12/13] - Loss: 0.455 - - Time: 0.022
Epoch 29 [0/13] - Loss: 0.501 - - Time: 0.023
Epoch 29 [1/13] - Loss: 0.469 - - Time: 0.022
Epoch 29 [2/13] - Loss: 0.488 - - Time: 0.022
Epoch 29 [3/13] - Loss: 0.498 - - Time: 0.023
Epoch 29 [4/13] - Loss: 0.486 - - Time: 0.022
Epoch 29 [5/13] - Loss: 0.451 - - Time: 0.022
Epoch 29 [6/13] - Loss: 0.471 - - Time: 0.023
Epoch 29 [7/13] - Loss: 0.433 - - Time: 0.023
Epoch 29 [8/13] - Loss: 0.483 - - Time: 0.022
Epoch 29 [9/13] - Loss: 0.479 - - Time: 0.023
Epoch 29 [10/13] - Loss: 0.465 - - Time: 0.022
Epoch 29 [11/13] - Loss: 0.495 - - Time: 0.022
Epoch 29 [12/13] - Loss: 0.443 - - Time: 0.023
Epoch 30 [0/13] - Loss: 0.447 - - Time: 0.022
Epoch 30 [1/13] - Loss: 0.448 - - Time: 0.022
Epoch 30 [2/13] - Loss: 0.462 - - Time: 0.023
Epoch 30 [3/13] - Loss: 0.493 - - Time: 0.022
Epoch 30 [4/13] - Loss: 0.483 - - Time: 0.022
Epoch 30 [5/13] - Loss: 0.450 - - Time: 0.023
Epoch 30 [6/13] - Loss: 0.465 - - Time: 0.022
Epoch 30 [7/13] - Loss: 0.485 - - Time: 0.022
Epoch 30 [8/13] - Loss: 0.441 - - Time: 0.023
Epoch 30 [9/13] - Loss: 0.453 - - Time: 0.022
Epoch 30 [10/13] - Loss: 0.453 - - Time: 0.022
Epoch 30 [11/13] - Loss: 0.496 - - Time: 0.023
Epoch 30 [12/13] - Loss: 0.519 - - Time: 0.022
Epoch 31 [0/13] - Loss: 0.435 - - Time: 0.022
Epoch 31 [1/13] - Loss: 0.401 - - Time: 0.022
Epoch 31 [2/13] - Loss: 0.450 - - Time: 0.022
Epoch 31 [3/13] - Loss: 0.469 - - Time: 0.022
Epoch 31 [4/13] - Loss: 0.466 - - Time: 0.023
Epoch 31 [5/13] - Loss: 0.444 - - Time: 0.022
Epoch 31 [6/13] - Loss: 0.498 - - Time: 0.022
Epoch 31 [7/13] - Loss: 0.480 - - Time: 0.023
Epoch 31 [8/13] - Loss: 0.511 - - Time: 0.022
Epoch 31 [9/13] - Loss: 0.456 - - Time: 0.023
Epoch 31 [10/13] - Loss: 0.495 - - Time: 0.023
Epoch 31 [11/13] - Loss: 0.449 - - Time: 0.023
Epoch 31 [12/13] - Loss: 0.468 - - Time: 0.023
Epoch 32 [0/13] - Loss: 0.439 - - Time: 0.023
Epoch 32 [1/13] - Loss: 0.460 - - Time: 0.023
Epoch 32 [2/13] - Loss: 0.446 - - Time: 0.023
Epoch 32 [3/13] - Loss: 0.487 - - Time: 0.023
Epoch 32 [4/13] - Loss: 0.465 - - Time: 0.022
Epoch 32 [5/13] - Loss: 0.487 - - Time: 0.023
Epoch 32 [6/13] - Loss: 0.463 - - Time: 0.023
Epoch 32 [7/13] - Loss: 0.439 - - Time: 0.023
Epoch 32 [8/13] - Loss: 0.481 - - Time: 0.022
Epoch 32 [9/13] - Loss: 0.474 - - Time: 0.023
Epoch 32 [10/13] - Loss: 0.452 - - Time: 0.022
Epoch 32 [11/13] - Loss: 0.477 - - Time: 0.022
Epoch 32 [12/13] - Loss: 0.405 - - Time: 0.023
Epoch 33 [0/13] - Loss: 0.435 - - Time: 0.023
Epoch 33 [1/13] - Loss: 0.474 - - Time: 0.022
Epoch 33 [2/13] - Loss: 0.414 - - Time: 0.023
Epoch 33 [3/13] - Loss: 0.461 - - Time: 0.022
Epoch 33 [4/13] - Loss: 0.484 - - Time: 0.022
Epoch 33 [5/13] - Loss: 0.465 - - Time: 0.023
Epoch 33 [6/13] - Loss: 0.423 - - Time: 0.023
Epoch 33 [7/13] - Loss: 0.468 - - Time: 0.022
Epoch 33 [8/13] - Loss: 0.462 - - Time: 0.023
Epoch 33 [9/13] - Loss: 0.461 - - Time: 0.023
Epoch 33 [10/13] - Loss: 0.419 - - Time: 0.022
Epoch 33 [11/13] - Loss: 0.468 - - Time: 0.023
Epoch 33 [12/13] - Loss: 0.482 - - Time: 0.022
Epoch 34 [0/13] - Loss: 0.421 - - Time: 0.022
Epoch 34 [1/13] - Loss: 0.444 - - Time: 0.023
Epoch 34 [2/13] - Loss: 0.460 - - Time: 0.022
Epoch 34 [3/13] - Loss: 0.458 - - Time: 0.023
Epoch 34 [4/13] - Loss: 0.473 - - Time: 0.023
Epoch 34 [5/13] - Loss: 0.460 - - Time: 0.022
Epoch 34 [6/13] - Loss: 0.458 - - Time: 0.023
Epoch 34 [7/13] - Loss: 0.467 - - Time: 0.022
Epoch 34 [8/13] - Loss: 0.427 - - Time: 0.022
Epoch 34 [9/13] - Loss: 0.423 - - Time: 0.023
Epoch 34 [10/13] - Loss: 0.456 - - Time: 0.022
Epoch 34 [11/13] - Loss: 0.403 - - Time: 0.022
Epoch 34 [12/13] - Loss: 0.479 - - Time: 0.023
Epoch 35 [0/13] - Loss: 0.433 - - Time: 0.023
Epoch 35 [1/13] - Loss: 0.426 - - Time: 0.022
Epoch 35 [2/13] - Loss: 0.430 - - Time: 0.023
Epoch 35 [3/13] - Loss: 0.467 - - Time: 0.022
Epoch 35 [4/13] - Loss: 0.405 - - Time: 0.022
Epoch 35 [5/13] - Loss: 0.408 - - Time: 0.023
Epoch 35 [6/13] - Loss: 0.474 - - Time: 0.023
Epoch 35 [7/13] - Loss: 0.468 - - Time: 0.022
Epoch 35 [8/13] - Loss: 0.429 - - Time: 0.023
Epoch 35 [9/13] - Loss: 0.474 - - Time: 0.022
Epoch 35 [10/13] - Loss: 0.428 - - Time: 0.022
Epoch 35 [11/13] - Loss: 0.438 - - Time: 0.023
Epoch 35 [12/13] - Loss: 0.473 - - Time: 0.022
Epoch 36 [0/13] - Loss: 0.420 - - Time: 0.022
Epoch 36 [1/13] - Loss: 0.445 - - Time: 0.023
Epoch 36 [2/13] - Loss: 0.449 - - Time: 0.022
Epoch 36 [3/13] - Loss: 0.461 - - Time: 0.022
Epoch 36 [4/13] - Loss: 0.432 - - Time: 0.023
Epoch 36 [5/13] - Loss: 0.496 - - Time: 0.022
Epoch 36 [6/13] - Loss: 0.413 - - Time: 0.022
Epoch 36 [7/13] - Loss: 0.464 - - Time: 0.023
Epoch 36 [8/13] - Loss: 0.443 - - Time: 0.022
Epoch 36 [9/13] - Loss: 0.463 - - Time: 0.022
Epoch 36 [10/13] - Loss: 0.403 - - Time: 0.023
Epoch 36 [11/13] - Loss: 0.461 - - Time: 0.023
Epoch 36 [12/13] - Loss: 0.401 - - Time: 0.023
Epoch 37 [0/13] - Loss: 0.400 - - Time: 0.023
Epoch 37 [1/13] - Loss: 0.442 - - Time: 0.023
Epoch 37 [2/13] - Loss: 0.462 - - Time: 0.023
Epoch 37 [3/13] - Loss: 0.439 - - Time: 0.024
Epoch 37 [4/13] - Loss: 0.447 - - Time: 0.023
Epoch 37 [5/13] - Loss: 0.418 - - Time: 0.028
Epoch 37 [6/13] - Loss: 0.487 - - Time: 0.024
Epoch 37 [7/13] - Loss: 0.436 - - Time: 0.022
Epoch 37 [8/13] - Loss: 0.397 - - Time: 0.022
Epoch 37 [9/13] - Loss: 0.451 - - Time: 0.023
Epoch 37 [10/13] - Loss: 0.484 - - Time: 0.022
Epoch 37 [11/13] - Loss: 0.448 - - Time: 0.022
Epoch 37 [12/13] - Loss: 0.444 - - Time: 0.023
Epoch 38 [0/13] - Loss: 0.441 - - Time: 0.022
Epoch 38 [1/13] - Loss: 0.446 - - Time: 0.022
Epoch 38 [2/13] - Loss: 0.425 - - Time: 0.023
Epoch 38 [3/13] - Loss: 0.461 - - Time: 0.022
Epoch 38 [4/13] - Loss: 0.430 - - Time: 0.022
Epoch 38 [5/13] - Loss: 0.433 - - Time: 0.023
Epoch 38 [6/13] - Loss: 0.472 - - Time: 0.022
Epoch 38 [7/13] - Loss: 0.474 - - Time: 0.023
Epoch 38 [8/13] - Loss: 0.390 - - Time: 0.023
Epoch 38 [9/13] - Loss: 0.373 - - Time: 0.022
Epoch 38 [10/13] - Loss: 0.461 - - Time: 0.022
Epoch 38 [11/13] - Loss: 0.452 - - Time: 0.023
Epoch 38 [12/13] - Loss: 0.469 - - Time: 0.022
Epoch 39 [0/13] - Loss: 0.473 - - Time: 0.022
Epoch 39 [1/13] - Loss: 0.427 - - Time: 0.023
Epoch 39 [2/13] - Loss: 0.435 - - Time: 0.022
Epoch 39 [3/13] - Loss: 0.405 - - Time: 0.022
Epoch 39 [4/13] - Loss: 0.419 - - Time: 0.024
Epoch 39 [5/13] - Loss: 0.429 - - Time: 0.022
Epoch 39 [6/13] - Loss: 0.426 - - Time: 0.022
Epoch 39 [7/13] - Loss: 0.485 - - Time: 0.023
Epoch 39 [8/13] - Loss: 0.440 - - Time: 0.022
Epoch 39 [9/13] - Loss: 0.459 - - Time: 0.022
Epoch 39 [10/13] - Loss: 0.474 - - Time: 0.023
Epoch 39 [11/13] - Loss: 0.410 - - Time: 0.022
Epoch 39 [12/13] - Loss: 0.421 - - Time: 0.023
Epoch 40 [0/13] - Loss: 0.478 - - Time: 0.023
Epoch 40 [1/13] - Loss: 0.379 - - Time: 0.022
Epoch 40 [2/13] - Loss: 0.473 - - Time: 0.022
Epoch 40 [3/13] - Loss: 0.428 - - Time: 0.023
Epoch 40 [4/13] - Loss: 0.493 - - Time: 0.022
Epoch 40 [5/13] - Loss: 0.403 - - Time: 0.022
Epoch 40 [6/13] - Loss: 0.425 - - Time: 0.023
Epoch 40 [7/13] - Loss: 0.434 - - Time: 0.022
Epoch 40 [8/13] - Loss: 0.451 - - Time: 0.022
Epoch 40 [9/13] - Loss: 0.417 - - Time: 0.022
Epoch 40 [10/13] - Loss: 0.430 - - Time: 0.023
Epoch 40 [11/13] - Loss: 0.439 - - Time: 0.022
Epoch 40 [12/13] - Loss: 0.431 - - Time: 0.022
Epoch 41 [0/13] - Loss: 0.451 - - Time: 0.022
Epoch 41 [1/13] - Loss: 0.430 - - Time: 0.022
Epoch 41 [2/13] - Loss: 0.405 - - Time: 0.023
Epoch 41 [3/13] - Loss: 0.411 - - Time: 0.022
Epoch 41 [4/13] - Loss: 0.405 - - Time: 0.023
Epoch 41 [5/13] - Loss: 0.478 - - Time: 0.023
Epoch 41 [6/13] - Loss: 0.439 - - Time: 0.023
Epoch 41 [7/13] - Loss: 0.400 - - Time: 0.023
Epoch 41 [8/13] - Loss: 0.455 - - Time: 0.023
Epoch 41 [9/13] - Loss: 0.435 - - Time: 0.023
Epoch 41 [10/13] - Loss: 0.468 - - Time: 0.023
Epoch 41 [11/13] - Loss: 0.429 - - Time: 0.024
Epoch 41 [12/13] - Loss: 0.436 - - Time: 0.023
Epoch 42 [0/13] - Loss: 0.403 - - Time: 0.023
Epoch 42 [1/13] - Loss: 0.416 - - Time: 0.023
Epoch 42 [2/13] - Loss: 0.450 - - Time: 0.023
Epoch 42 [3/13] - Loss: 0.444 - - Time: 0.023
Epoch 42 [4/13] - Loss: 0.413 - - Time: 0.024
Epoch 42 [5/13] - Loss: 0.435 - - Time: 0.023
Epoch 42 [6/13] - Loss: 0.426 - - Time: 0.022
Epoch 42 [7/13] - Loss: 0.440 - - Time: 0.023
Epoch 42 [8/13] - Loss: 0.411 - - Time: 0.023
Epoch 42 [9/13] - Loss: 0.418 - - Time: 0.024
Epoch 42 [10/13] - Loss: 0.505 - - Time: 0.023
Epoch 42 [11/13] - Loss: 0.468 - - Time: 0.023
Epoch 42 [12/13] - Loss: 0.442 - - Time: 0.023
Epoch 43 [0/13] - Loss: 0.363 - - Time: 0.023
Epoch 43 [1/13] - Loss: 0.414 - - Time: 0.022
Epoch 43 [2/13] - Loss: 0.403 - - Time: 0.022
Epoch 43 [3/13] - Loss: 0.469 - - Time: 0.023
Epoch 43 [4/13] - Loss: 0.453 - - Time: 0.023
Epoch 43 [5/13] - Loss: 0.432 - - Time: 0.023
Epoch 43 [6/13] - Loss: 0.450 - - Time: 0.023
Epoch 43 [7/13] - Loss: 0.487 - - Time: 0.027
Epoch 43 [8/13] - Loss: 0.418 - - Time: 0.022
Epoch 43 [9/13] - Loss: 0.418 - - Time: 0.023
Epoch 43 [10/13] - Loss: 0.448 - - Time: 0.023
Epoch 43 [11/13] - Loss: 0.451 - - Time: 0.023
Epoch 43 [12/13] - Loss: 0.513 - - Time: 0.023
Epoch 44 [0/13] - Loss: 0.430 - - Time: 0.022
Epoch 44 [1/13] - Loss: 0.400 - - Time: 0.022
Epoch 44 [2/13] - Loss: 0.444 - - Time: 0.022
Epoch 44 [3/13] - Loss: 0.454 - - Time: 0.022
Epoch 44 [4/13] - Loss: 0.460 - - Time: 0.023
Epoch 44 [5/13] - Loss: 0.452 - - Time: 0.024
Epoch 44 [6/13] - Loss: 0.432 - - Time: 0.023
Epoch 44 [7/13] - Loss: 0.429 - - Time: 0.022
Epoch 44 [8/13] - Loss: 0.406 - - Time: 0.023
Epoch 44 [9/13] - Loss: 0.449 - - Time: 0.023
Epoch 44 [10/13] - Loss: 0.446 - - Time: 0.023
Epoch 44 [11/13] - Loss: 0.446 - - Time: 0.023
Epoch 44 [12/13] - Loss: 0.439 - - Time: 0.022
Epoch 45 [0/13] - Loss: 0.398 - - Time: 0.022
Epoch 45 [1/13] - Loss: 0.401 - - Time: 0.023
Epoch 45 [2/13] - Loss: 0.445 - - Time: 0.023
Epoch 45 [3/13] - Loss: 0.432 - - Time: 0.022
Epoch 45 [4/13] - Loss: 0.444 - - Time: 0.023
Epoch 45 [5/13] - Loss: 0.439 - - Time: 0.023
Epoch 45 [6/13] - Loss: 0.415 - - Time: 0.029
Epoch 45 [7/13] - Loss: 0.405 - - Time: 0.023
Epoch 45 [8/13] - Loss: 0.512 - - Time: 0.023
Epoch 45 [9/13] - Loss: 0.425 - - Time: 0.022
Epoch 45 [10/13] - Loss: 0.372 - - Time: 0.023
Epoch 45 [11/13] - Loss: 0.464 - - Time: 0.022
Epoch 45 [12/13] - Loss: 0.423 - - Time: 0.022
Epoch 46 [0/13] - Loss: 0.410 - - Time: 0.023
Epoch 46 [1/13] - Loss: 0.408 - - Time: 0.022
Epoch 46 [2/13] - Loss: 0.435 - - Time: 0.022
Epoch 46 [3/13] - Loss: 0.453 - - Time: 0.023
Epoch 46 [4/13] - Loss: 0.465 - - Time: 0.022
Epoch 46 [5/13] - Loss: 0.389 - - Time: 0.022
Epoch 46 [6/13] - Loss: 0.438 - - Time: 0.022
Epoch 46 [7/13] - Loss: 0.437 - - Time: 0.023
Epoch 46 [8/13] - Loss: 0.393 - - Time: 0.022
Epoch 46 [9/13] - Loss: 0.420 - - Time: 0.023
Epoch 46 [10/13] - Loss: 0.445 - - Time: 0.022
Epoch 46 [11/13] - Loss: 0.440 - - Time: 0.022
Epoch 46 [12/13] - Loss: 0.435 - - Time: 0.023
Epoch 47 [0/13] - Loss: 0.426 - - Time: 0.023
Epoch 47 [1/13] - Loss: 0.438 - - Time: 0.022
Epoch 47 [2/13] - Loss: 0.426 - - Time: 0.023
Epoch 47 [3/13] - Loss: 0.426 - - Time: 0.023
Epoch 47 [4/13] - Loss: 0.447 - - Time: 0.022
Epoch 47 [5/13] - Loss: 0.418 - - Time: 0.022
Epoch 47 [6/13] - Loss: 0.421 - - Time: 0.022
Epoch 47 [7/13] - Loss: 0.429 - - Time: 0.022
Epoch 47 [8/13] - Loss: 0.430 - - Time: 0.023
Epoch 47 [9/13] - Loss: 0.439 - - Time: 0.022
Epoch 47 [10/13] - Loss: 0.421 - - Time: 0.022
Epoch 47 [11/13] - Loss: 0.464 - - Time: 0.023
Epoch 47 [12/13] - Loss: 0.411 - - Time: 0.022
Epoch 48 [0/13] - Loss: 0.409 - - Time: 0.023
Epoch 48 [1/13] - Loss: 0.385 - - Time: 0.023
Epoch 48 [2/13] - Loss: 0.407 - - Time: 0.022
Epoch 48 [3/13] - Loss: 0.425 - - Time: 0.022
Epoch 48 [4/13] - Loss: 0.411 - - Time: 0.023
Epoch 48 [5/13] - Loss: 0.463 - - Time: 0.022
Epoch 48 [6/13] - Loss: 0.472 - - Time: 0.022
Epoch 48 [7/13] - Loss: 0.464 - - Time: 0.023
Epoch 48 [8/13] - Loss: 0.469 - - Time: 0.022
Epoch 48 [9/13] - Loss: 0.399 - - Time: 0.023
Epoch 48 [10/13] - Loss: 0.408 - - Time: 0.023
Epoch 48 [11/13] - Loss: 0.436 - - Time: 0.022
Epoch 48 [12/13] - Loss: 0.427 - - Time: 0.022
Epoch 49 [0/13] - Loss: 0.404 - - Time: 0.023
Epoch 49 [1/13] - Loss: 0.419 - - Time: 0.022
Epoch 49 [2/13] - Loss: 0.420 - - Time: 0.022
Epoch 49 [3/13] - Loss: 0.460 - - Time: 0.023
Epoch 49 [4/13] - Loss: 0.432 - - Time: 0.023
Epoch 49 [5/13] - Loss: 0.399 - - Time: 0.022
Epoch 49 [6/13] - Loss: 0.475 - - Time: 0.023
Epoch 49 [7/13] - Loss: 0.415 - - Time: 0.022
Epoch 49 [8/13] - Loss: 0.428 - - Time: 0.022
Epoch 49 [9/13] - Loss: 0.439 - - Time: 0.028
Epoch 49 [10/13] - Loss: 0.439 - - Time: 0.022
Epoch 49 [11/13] - Loss: 0.383 - - Time: 0.022
Epoch 49 [12/13] - Loss: 0.409 - - Time: 0.023
Epoch 50 [0/13] - Loss: 0.423 - - Time: 0.022
Epoch 50 [1/13] - Loss: 0.440 - - Time: 0.022
Epoch 50 [2/13] - Loss: 0.402 - - Time: 0.023
Epoch 50 [3/13] - Loss: 0.496 - - Time: 0.022
Epoch 50 [4/13] - Loss: 0.446 - - Time: 0.023
Epoch 50 [5/13] - Loss: 0.405 - - Time: 0.022
Epoch 50 [6/13] - Loss: 0.406 - - Time: 0.022
Epoch 50 [7/13] - Loss: 0.410 - - Time: 0.022
Epoch 50 [8/13] - Loss: 0.390 - - Time: 0.022
Epoch 50 [9/13] - Loss: 0.450 - - Time: 0.022
Epoch 50 [10/13] - Loss: 0.426 - - Time: 0.022
Epoch 50 [11/13] - Loss: 0.424 - - Time: 0.023
Epoch 50 [12/13] - Loss: 0.449 - - Time: 0.022
Epoch 51 [0/13] - Loss: 0.452 - - Time: 0.022
Epoch 51 [1/13] - Loss: 0.427 - - Time: 0.023
Epoch 51 [2/13] - Loss: 0.418 - - Time: 0.023
Epoch 51 [3/13] - Loss: 0.466 - - Time: 0.022
Epoch 51 [4/13] - Loss: 0.425 - - Time: 0.022
Epoch 51 [5/13] - Loss: 0.474 - - Time: 0.023
Epoch 51 [6/13] - Loss: 0.442 - - Time: 0.022
Epoch 51 [7/13] - Loss: 0.430 - - Time: 0.022
Epoch 51 [8/13] - Loss: 0.402 - - Time: 0.023
Epoch 51 [9/13] - Loss: 0.463 - - Time: 0.022
Epoch 51 [10/13] - Loss: 0.390 - - Time: 0.022
Epoch 51 [11/13] - Loss: 0.422 - - Time: 0.023
Epoch 51 [12/13] - Loss: 0.387 - - Time: 0.022
Epoch 52 [0/13] - Loss: 0.411 - - Time: 0.022
Epoch 52 [1/13] - Loss: 0.426 - - Time: 0.023
Epoch 52 [2/13] - Loss: 0.410 - - Time: 0.022
Epoch 52 [3/13] - Loss: 0.428 - - Time: 0.022
Epoch 52 [4/13] - Loss: 0.508 - - Time: 0.023
Epoch 52 [5/13] - Loss: 0.425 - - Time: 0.022
Epoch 52 [6/13] - Loss: 0.431 - - Time: 0.022
Epoch 52 [7/13] - Loss: 0.453 - - Time: 0.023
Epoch 52 [8/13] - Loss: 0.425 - - Time: 0.022
Epoch 52 [9/13] - Loss: 0.411 - - Time: 0.022
Epoch 52 [10/13] - Loss: 0.445 - - Time: 0.023
Epoch 52 [11/13] - Loss: 0.405 - - Time: 0.022
Epoch 52 [12/13] - Loss: 0.475 - - Time: 0.022
Epoch 53 [0/13] - Loss: 0.407 - - Time: 0.023
Epoch 53 [1/13] - Loss: 0.384 - - Time: 0.022
Epoch 53 [2/13] - Loss: 0.467 - - Time: 0.023
Epoch 53 [3/13] - Loss: 0.419 - - Time: 0.023
Epoch 53 [4/13] - Loss: 0.384 - - Time: 0.022
Epoch 53 [5/13] - Loss: 0.494 - - Time: 0.022
Epoch 53 [6/13] - Loss: 0.440 - - Time: 0.023
Epoch 53 [7/13] - Loss: 0.398 - - Time: 0.022
Epoch 53 [8/13] - Loss: 0.490 - - Time: 0.022
Epoch 53 [9/13] - Loss: 0.476 - - Time: 0.023
Epoch 53 [10/13] - Loss: 0.441 - - Time: 0.022
Epoch 53 [11/13] - Loss: 0.490 - - Time: 0.022
Epoch 53 [12/13] - Loss: 0.397 - - Time: 0.022
Epoch 54 [0/13] - Loss: 0.426 - - Time: 0.022
Epoch 54 [1/13] - Loss: 0.375 - - Time: 0.023
Epoch 54 [2/13] - Loss: 0.463 - - Time: 0.023
Epoch 54 [3/13] - Loss: 0.472 - - Time: 0.022
Epoch 54 [4/13] - Loss: 0.432 - - Time: 0.022
Epoch 54 [5/13] - Loss: 0.405 - - Time: 0.022
Epoch 54 [6/13] - Loss: 0.442 - - Time: 0.022
Epoch 54 [7/13] - Loss: 0.465 - - Time: 0.023
Epoch 54 [8/13] - Loss: 0.406 - - Time: 0.023
Epoch 54 [9/13] - Loss: 0.401 - - Time: 0.023
Epoch 54 [10/13] - Loss: 0.498 - - Time: 0.022
Epoch 54 [11/13] - Loss: 0.426 - - Time: 0.022
Epoch 54 [12/13] - Loss: 0.423 - - Time: 0.023
Epoch 55 [0/13] - Loss: 0.412 - - Time: 0.022
Epoch 55 [1/13] - Loss: 0.428 - - Time: 0.022
Epoch 55 [2/13] - Loss: 0.421 - - Time: 0.023
Epoch 55 [3/13] - Loss: 0.425 - - Time: 0.022
Epoch 55 [4/13] - Loss: 0.427 - - Time: 0.022
Epoch 55 [5/13] - Loss: 0.433 - - Time: 0.023
Epoch 55 [6/13] - Loss: 0.414 - - Time: 0.023
Epoch 55 [7/13] - Loss: 0.398 - - Time: 0.022
Epoch 55 [8/13] - Loss: 0.443 - - Time: 0.023
Epoch 55 [9/13] - Loss: 0.425 - - Time: 0.022
Epoch 55 [10/13] - Loss: 0.458 - - Time: 0.023
Epoch 55 [11/13] - Loss: 0.416 - - Time: 0.023
Epoch 55 [12/13] - Loss: 0.447 - - Time: 0.022
Epoch 56 [0/13] - Loss: 0.372 - - Time: 0.022
Epoch 56 [1/13] - Loss: 0.454 - - Time: 0.023
Epoch 56 [2/13] - Loss: 0.432 - - Time: 0.022
Epoch 56 [3/13] - Loss: 0.452 - - Time: 0.023
Epoch 56 [4/13] - Loss: 0.434 - - Time: 0.023
Epoch 56 [5/13] - Loss: 0.391 - - Time: 0.022
Epoch 56 [6/13] - Loss: 0.367 - - Time: 0.022
Epoch 56 [7/13] - Loss: 0.452 - - Time: 0.023
Epoch 56 [8/13] - Loss: 0.441 - - Time: 0.022
Epoch 56 [9/13] - Loss: 0.461 - - Time: 0.022
Epoch 56 [10/13] - Loss: 0.408 - - Time: 0.023
Epoch 56 [11/13] - Loss: 0.433 - - Time: 0.022
Epoch 56 [12/13] - Loss: 0.425 - - Time: 0.022
Epoch 57 [0/13] - Loss: 0.445 - - Time: 0.022
Epoch 57 [1/13] - Loss: 0.439 - - Time: 0.022
Epoch 57 [2/13] - Loss: 0.425 - - Time: 0.023
Epoch 57 [3/13] - Loss: 0.408 - - Time: 0.023
Epoch 57 [4/13] - Loss: 0.378 - - Time: 0.023
Epoch 57 [5/13] - Loss: 0.414 - - Time: 0.022
Epoch 57 [6/13] - Loss: 0.426 - - Time: 0.024
Epoch 57 [7/13] - Loss: 0.449 - - Time: 0.022
Epoch 57 [8/13] - Loss: 0.437 - - Time: 0.022
Epoch 57 [9/13] - Loss: 0.384 - - Time: 0.023
Epoch 57 [10/13] - Loss: 0.411 - - Time: 0.023
Epoch 57 [11/13] - Loss: 0.410 - - Time: 0.023
Epoch 57 [12/13] - Loss: 0.437 - - Time: 0.023
Epoch 58 [0/13] - Loss: 0.460 - - Time: 0.022
Epoch 58 [1/13] - Loss: 0.436 - - Time: 0.023
Epoch 58 [2/13] - Loss: 0.426 - - Time: 0.023
Epoch 58 [3/13] - Loss: 0.486 - - Time: 0.023
Epoch 58 [4/13] - Loss: 0.384 - - Time: 0.022
Epoch 58 [5/13] - Loss: 0.439 - - Time: 0.023
Epoch 58 [6/13] - Loss: 0.405 - - Time: 0.023
Epoch 58 [7/13] - Loss: 0.432 - - Time: 0.023
Epoch 58 [8/13] - Loss: 0.390 - - Time: 0.023
Epoch 58 [9/13] - Loss: 0.428 - - Time: 0.022
Epoch 58 [10/13] - Loss: 0.375 - - Time: 0.022
Epoch 58 [11/13] - Loss: 0.389 - - Time: 0.023
Epoch 58 [12/13] - Loss: 0.423 - - Time: 0.023
Epoch 59 [0/13] - Loss: 0.373 - - Time: 0.023
Epoch 59 [1/13] - Loss: 0.446 - - Time: 0.023
Epoch 59 [2/13] - Loss: 0.389 - - Time: 0.023
Epoch 59 [3/13] - Loss: 0.419 - - Time: 0.022
Epoch 59 [4/13] - Loss: 0.459 - - Time: 0.024
Epoch 59 [5/13] - Loss: 0.391 - - Time: 0.023
Epoch 59 [6/13] - Loss: 0.335 - - Time: 0.022
Epoch 59 [7/13] - Loss: 0.456 - - Time: 0.023
Epoch 59 [8/13] - Loss: 0.445 - - Time: 0.022
Epoch 59 [9/13] - Loss: 0.446 - - Time: 0.022
Epoch 59 [10/13] - Loss: 0.421 - - Time: 0.023
Epoch 59 [11/13] - Loss: 0.420 - - Time: 0.022
Epoch 59 [12/13] - Loss: 0.443 - - Time: 0.023
Epoch 60 [0/13] - Loss: 0.433 - - Time: 0.023
Epoch 60 [1/13] - Loss: 0.417 - - Time: 0.023
Epoch 60 [2/13] - Loss: 0.432 - - Time: 0.022
Epoch 60 [3/13] - Loss: 0.393 - - Time: 0.022
Epoch 60 [4/13] - Loss: 0.446 - - Time: 0.022
Epoch 60 [5/13] - Loss: 0.443 - - Time: 0.022
Epoch 60 [6/13] - Loss: 0.415 - - Time: 0.023
Epoch 60 [7/13] - Loss: 0.475 - - Time: 0.022
Epoch 60 [8/13] - Loss: 0.434 - - Time: 0.022
Epoch 60 [9/13] - Loss: 0.409 - - Time: 0.023
Epoch 60 [10/13] - Loss: 0.371 - - Time: 0.022
Epoch 60 [11/13] - Loss: 0.373 - - Time: 0.022
Epoch 60 [12/13] - Loss: 0.418 - - Time: 0.023

Make predictions and get an RMSE

The output distribution of a deep GP in this framework is actually a mixture of num_samples Gaussians for each output. We get predictions the same way as with all GPyTorch models, but we do currently need to do some reshaping to get the means and variances in a reasonable form.
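
If a single predictive mean and variance per test point is more convenient than the per-sample quantities, the sample mixture can be collapsed by moment matching. A minimal sketch (assuming model.predict returns tensors whose leading dimension indexes samples, as defined above):

# Collapse the mixture of num_samples Gaussians into one Gaussian per test
# point via the law of total variance
sample_means, sample_vars = model.predict(test_x)
mixture_mean = sample_means.mean(0)
mixture_var = (sample_vars + sample_means.pow(2)).mean(0) - mixture_mean.pow(2)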

For reference, SVGP gets an RMSE of around 0.41 after 60 epochs of training on this dataset, so an RMSE of 0.35 from a two-layer deep GP without much tuning is pretty good!

[43]:
import math

model.eval()
predictive_means, predictive_variances = model.predict(test_x)

# RMSE of the sample-averaged mean prediction
rmse = torch.mean(torch.pow(predictive_means.mean(0) - test_y, 2)).sqrt()
# Mixture test log-likelihood: log (1/S) sum_s N(y | mu_s, sigma_s^2)
with torch.no_grad():
    test_ll = torch.distributions.Normal(predictive_means, predictive_variances.sqrt()).log_prob(test_y).logsumexp(dim=0) - math.log(predictive_means.size(0))

print(rmse)
print(test_ll.mean())
tensor(0.3528, device='cuda:0')
tensor(-0.3607, device='cuda:0')