gpytorch.likelihoods

Likelihood

class gpytorch.likelihoods.Likelihood(max_plate_nesting=1)[source]

A Likelihood in GPyTorch specifies the mapping from latent function values \(f(\mathbf X)\) to observed labels \(y\).

For example, in the case of regression this might be a Gaussian distribution, as \(y(\mathbf x)\) is equal to \(f(\mathbf x)\) plus Gaussian noise:

\[y(\mathbf x) = f(\mathbf x) + \epsilon, \:\:\:\: \epsilon \sim N(0,\sigma^{2}_{n} \mathbf I)\]

In the case of classification, this might be a Bernoulli distribution, where the probability that \(y=1\) is given by the latent function passed through some sigmoid or probit function:

\[\begin{split}y(\mathbf x) = \begin{cases} 1 & \text{w/ probability} \:\: \sigma(f(\mathbf x)) \\ 0 & \text{w/ probability} \:\: 1-\sigma(f(\mathbf x)) \end{cases}\end{split}\]

In either case, to implement a likelihood function, GPyTorch only requires a forward method that computes the conditional distribution \(p(y \mid f(\mathbf x))\).
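For illustration, here is a minimal sketch of such a likelihood (the SimpleGaussianLikelihood class and its softplus noise transform are hypothetical, not part of the library; GaussianLikelihood below is the built-in equivalent):

import torch
import gpytorch

class SimpleGaussianLikelihood(gpytorch.likelihoods.Likelihood):
    """Hypothetical sketch: a Gaussian likelihood defined via `forward` alone."""

    def __init__(self):
        super().__init__()
        # Unconstrained parameter; softplus keeps the noise variance positive.
        self.raw_noise = torch.nn.Parameter(torch.zeros(1))

    def forward(self, function_samples, **kwargs):
        # The conditional p(y | f(x)) = N(y; f(x), sigma^2)
        noise = torch.nn.functional.softplus(self.raw_noise)
        return torch.distributions.Normal(function_samples, noise.sqrt())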

Parameters:
  • has_analytic_marginal (bool) – Whether or not the marginal distribution \(p(\mathbf y)\) can be computed in closed form. (See __call__() docstring.)

  • max_plate_nesting (int) – (For Pyro integration only.) How many batch dimensions are in the function. This should be modified if the likelihood uses plated random variables. (Default = 1)

  • name_prefix (str) – (For Pyro integration only.) Prefix to assign to named Pyro latent variables.

  • num_data (int) – (For Pyro integration only.) Total number of observations.

__call__(input, *args, **kwargs)[source]

Calling this object does one of two things:

  1. If the likelihood is called with a torch.Tensor object, then it is assumed that the input is samples from \(f(\mathbf x)\). This returns the conditional distribution \(p(y \mid f(\mathbf x))\).

import torch
import gpytorch

f = torch.randn(20)
likelihood = gpytorch.likelihoods.GaussianLikelihood()
conditional = likelihood(f)
print(type(conditional), conditional.batch_shape, conditional.event_shape)
# >>> torch.distributions.Normal, torch.Size([20]), torch.Size([])
  2. If the likelihood is called with a MultivariateNormal object, then it is assumed that the input is the distribution \(f(\mathbf x)\). This returns the marginal distribution \(p(y \mid \mathbf x)\).

    The form of the marginal distribution depends on the likelihood. For BernoulliLikelihood and GaussianLikelihood objects, the marginal distribution can be computed analytically, and the likelihood returns the analytic distribution. For most other likelihoods, there is no analytic form for the marginal, and so the likelihood instead returns a batch of Monte Carlo samples from the marginal.

import torch
import linear_operator
import gpytorch

mean = torch.randn(20)
covar = linear_operator.operators.DiagLinearOperator(torch.ones(20))
f = gpytorch.distributions.MultivariateNormal(mean, covar)

# Analytic marginal computation - Bernoulli and Gaussian likelihoods only
analytic_marginal_likelihood = gpytorch.likelihoods.GaussianLikelihood()
marginal = analytic_marginal_likelihood(f)
print(type(marginal), marginal.batch_shape, marginal.event_shape)
# >>> gpytorch.distributions.MultivariateNormal, torch.Size([]), torch.Size([20])

# MC marginal computation - all other likelihoods
mc_marginal_likelihood = gpytorch.likelihoods.BetaLikelihood()
with gpytorch.settings.num_likelihood_samples(15):
    marginal = mc_marginal_likelihood(f)
print(type(marginal), marginal.batch_shape, marginal.event_shape)
# >>> torch.distributions.Beta, torch.Size([15, 20]), torch.Size([])
# (The batch_shape of torch.Size([15, 20]) represents 15 MC samples for 20 data points.)

Note

If a Likelihood supports analytic marginals, the has_analytic_marginal property will be True. If a Likelihood does not support analytic marginals, you can set the number of Monte Carlo samples using the gpytorch.settings.num_likelihood_samples context manager.

Parameters:
  • input (torch.Tensor or MultivariateNormal) – Either a (… x N) sample from \(\mathbf f\) or a (… x N) MVN distribution of \(\mathbf f\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

torch.distributions.distribution.Distribution

Returns:

Either a conditional \(p(\mathbf y \mid \mathbf f)\) or marginal \(p(\mathbf y)\) based on whether input is a Tensor or a MultivariateNormal (see above).

expected_log_prob(observations, function_dist, *args, **kwargs)[source]

(Used by VariationalELBO for variational inference.)

Computes the expected log likelihood, where the expectation is over the GP variational distribution.

\[\sum_{\mathbf x, y} \mathbb{E}_{q\left( f(\mathbf x) \right)} \left[ \log p \left( y \mid f(\mathbf x) \right) \right]\]
Parameters:
  • observations (torch.Tensor) – Values of \(y\).

  • function_dist (MultivariateNormal) – Distribution for \(f(x)\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

torch.Tensor
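
A short usage sketch (the unit-Gaussian \(q(f)\) here is a stand-in; in practice it comes from an approximate GP model):

import torch
import gpytorch

likelihood = gpytorch.likelihoods.GaussianLikelihood()
y = torch.randn(20)  # observed targets
q_f = gpytorch.distributions.MultivariateNormal(torch.zeros(20), torch.eye(20))
elbo_term = likelihood.expected_log_prob(y, q_f).sum()  # E_q[log p(y | f)]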

abstract forward(function_samples, *args, data={}, **kwargs)[source]

Computes the conditional distribution \(p(\mathbf y \mid \mathbf f, \ldots)\) that defines the likelihood.

Parameters:
  • function_samples (torch.Tensor) – Samples from the function (\(\mathbf f\))

  • data (Dict[str, torch.Tensor]) – (Pyro integration only.) Additional variables that the likelihood needs to condition on. The keys of the dictionary will correspond to Pyro sample sites in the likelihood’s model/guide.

  • args – Additional args

  • kwargs – Additional kwargs

Return type:

torch.distributions.distribution.Distribution

log_marginal(observations, function_dist, *args, **kwargs)[source]

(Used by PredictiveLogLikelihood for approximate inference.)

Computes the log marginal likelihood of the approximate predictive distribution

\[\sum_{\mathbf x, y} \log \mathbb{E}_{q\left( f(\mathbf x) \right)} \left[ p \left( y \mid f(\mathbf x) \right) \right]\]

Note that this differs from expected_log_prob() because the \(\log\) is on the outside of the expectation.

Parameters:
  • observations (torch.Tensor) – Values of \(y\).

  • function_dist (MultivariateNormal) – Distribution for \(f(x)\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

torch.Tensor
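
To make the difference concrete, both objectives can be evaluated on the same inputs; by Jensen's inequality the expected log probability never exceeds the log marginal (a sketch with a stand-in \(q(f)\)):

import torch
import gpytorch

likelihood = gpytorch.likelihoods.GaussianLikelihood()
y = torch.randn(10)
q_f = gpytorch.distributions.MultivariateNormal(torch.zeros(10), torch.eye(10))

elbo_term = likelihood.expected_log_prob(y, q_f).sum()  # E_q[log p(y | f)]
pll_term = likelihood.log_marginal(y, q_f).sum()        # log E_q[p(y | f)]
assert elbo_term <= pll_term  # Jensen's inequality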

marginal(function_dist, *args, **kwargs)[source]

Computes a predictive distribution \(p(y^* | \mathbf x^*)\) given either a posterior distribution \(p(\mathbf f | \mathcal D, \mathbf x)\) or a prior distribution \(p(\mathbf f|\mathbf x)\) as input.

With both exact inference and variational inference, the form of \(p(\mathbf f|\mathcal D, \mathbf x)\) or \(p(\mathbf f| \mathbf x)\) should usually be Gaussian. As a result, function_dist should usually be a MultivariateNormal specified by the mean and (co)variance of \(p(\mathbf f|...)\).

Parameters:
  • function_dist (MultivariateNormal) – Distribution for \(f(x)\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

torch.distributions.distribution.Distribution

Returns:

The marginal distribution, or samples from it.
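
Calling the likelihood object on a MultivariateNormal (see __call__() above) dispatches to this method, so the two forms in this sketch are equivalent:

import torch
import gpytorch

likelihood = gpytorch.likelihoods.GaussianLikelihood()
f_dist = gpytorch.distributions.MultivariateNormal(torch.zeros(5), torch.eye(5))

y_dist = likelihood.marginal(f_dist)  # explicit call
y_dist = likelihood(f_dist)           # equivalent via __call__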

pyro_guide(function_dist, target, *args, **kwargs)[source]

(For Pyro integration only).

Part of the guide function for the likelihood. This should be re-defined if the likelihood contains any latent variables that need to be inferred.

Parameters:
  • function_dist (MultivariateNormal) – Distribution of latent function \(q(\mathbf f)\).

  • target (torch.Tensor) – Observed \(\mathbf y\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

NoneType

pyro_model(function_dist, target, *args, **kwargs)[source]

(For Pyro integration only).

Part of the model function for the likelihood. This should be re-defined if the likelihood contains any latent variables that need to be inferred.

Parameters:
  • function_dist (MultivariateNormal) – Distribution of latent function \(p(\mathbf f)\).

  • target (torch.Tensor) – Observed \(\mathbf y\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

torch.Tensor

One-Dimensional Likelihoods

Likelihoods for GPs that are distributions of scalar functions. (I.e. for a specific \(\mathbf x\) we expect that \(f(\mathbf x) \in \mathbb{R}\).)

One-dimensional likelihoods should extend gpytorch.likelihoods._OneDimensionalLikelihood to reduce the variance when computing approximate GP objective functions. (Variance reduction is accomplished by using 1D Gauss-Hermite quadrature rather than MC integration.)
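
For example, a count-data likelihood could extend _OneDimensionalLikelihood as follows (a hypothetical sketch; a Poisson likelihood is not among the built-in classes documented here):

import torch
import gpytorch

class PoissonLikelihood(gpytorch.likelihoods._OneDimensionalLikelihood):
    def forward(self, function_samples, **kwargs):
        # The Poisson rate must be positive, so warp f through exp.
        return torch.distributions.Poisson(rate=function_samples.exp())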

GaussianLikelihood

class gpytorch.likelihoods.GaussianLikelihood(noise_prior=None, noise_constraint=None, batch_shape=torch.Size([]), **kwargs)[source]

The standard likelihood for regression. Assumes a standard homoskedastic noise model:

\[y = f + \epsilon, \quad \epsilon \sim \mathcal N (0, \sigma^2)\]

where \(\sigma^2\) is a noise parameter.

Note

This likelihood can be used for exact or approximate inference.

Note

GaussianLikelihood has an analytic marginal distribution.

Parameters:
  • noise_prior (Prior, optional) – Prior for noise parameter \(\sigma^2\).

  • noise_constraint (Interval, optional) – Constraint for noise parameter \(\sigma^2\).

  • batch_shape (torch.Size) – The batch shape of the learned noise parameter (default: []).

  • kwargs

Variables:

noise (torch.Tensor) – \(\sigma^2\) parameter (noise)

marginal(function_dist, *args, **kwargs)[source]

Parameters:
  • function_dist (MultivariateNormal) – Distribution for \(f(x)\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

MultivariateNormal

Returns:

Analytic marginal \(p(\mathbf y)\).
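
A brief usage sketch (the noise constraint value is arbitrary, chosen only for illustration):

import torch
import gpytorch

likelihood = gpytorch.likelihoods.GaussianLikelihood(
    noise_constraint=gpytorch.constraints.GreaterThan(1e-4),
)
print(likelihood.noise)  # the sigma^2 parameter

f_dist = gpytorch.distributions.MultivariateNormal(torch.zeros(5), torch.eye(5))
y_dist = likelihood(f_dist)  # analytic Gaussian marginal p(y)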

GaussianLikelihoodWithMissingObs

class gpytorch.likelihoods.GaussianLikelihoodWithMissingObs(**kwargs)[source]

The standard likelihood for regression with support for missing values. Assumes a standard homoskedastic noise model:

\[y = f + \epsilon, \quad \epsilon \sim \mathcal N (0, \sigma^2)\]

where \(\sigma^2\) is a noise parameter. Values of y that are nan do not impact the likelihood calculation.

Note

This likelihood can be used for exact or approximate inference.

Warning

This likelihood is deprecated in favor of gpytorch.settings.observation_nan_policy.

Parameters:
  • noise_prior (Prior, optional) – Prior for noise parameter \(\sigma^2\).

  • noise_constraint (Interval, optional) – Constraint for noise parameter \(\sigma^2\).

  • batch_shape (torch.Size, optional) – The batch shape of the learned noise parameter (default: []).

  • kwargs

Variables:

noise (torch.Tensor) – \(\sigma^2\) parameter (noise)

Note

GaussianLikelihoodWithMissingObs has an analytic marginal distribution.

marginal(function_dist, *args, **kwargs)[source]

Parameters:
  • function_dist (MultivariateNormal) – Distribution for \(f(x)\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

MultivariateNormal

Returns:

Analytic marginal \(p(\mathbf y)\).

FixedNoiseGaussianLikelihood

class gpytorch.likelihoods.FixedNoiseGaussianLikelihood(noise, learn_additional_noise=False, batch_shape=torch.Size([]), **kwargs)[source]

A Likelihood that assumes fixed heteroscedastic noise. This is useful when you have fixed, known observation noise for each training example.

Note that this likelihood takes an additional argument when you call it, noise, that adds a specified amount of noise to the passed MultivariateNormal. This allows for adding known observational noise to test data.

Note

This likelihood can be used for exact or approximate inference.

Parameters:
  • noise (torch.Tensor (... x N)) – Known observation noise (variance) for each training example.

  • learn_additional_noise (bool, optional) – Set to true if you additionally want to learn added diagonal noise, similar to GaussianLikelihood.

  • batch_shape (torch.Size, optional) – The batch shape of the learned noise parameter (default []) if learn_additional_noise=True.

  • kwargs

Variables:

noise (torch.Tensor) – \(\sigma^2\) parameter (noise)

Note

FixedNoiseGaussianLikelihood has an analytic marginal distribution.

Example

>>> train_x = torch.randn(55, 2)
>>> noises = torch.ones(55) * 0.01
>>> likelihood = FixedNoiseGaussianLikelihood(noise=noises, learn_additional_noise=True)
>>> pred_y = likelihood(gp_model(train_x))
>>>
>>> test_x = torch.randn(21, 2)
>>> test_noises = torch.ones(21) * 0.02
>>> pred_y = likelihood(gp_model(test_x), noise=test_noises)
marginal(function_dist, *args, **kwargs)[source]

Parameters:
  • function_dist (MultivariateNormal) – Distribution for \(f(x)\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

MultivariateNormal

Returns:

Analytic marginal \(p(\mathbf y)\).

DirichletClassificationLikelihood

class gpytorch.likelihoods.DirichletClassificationLikelihood(targets, alpha_epsilon=0.01, learn_additional_noise=False, batch_shape=torch.Size([]), dtype=torch.float32, **kwargs)[source]

A classification likelihood that treats the labels as regression targets with fixed heteroscedastic noise. From Milios et al., NeurIPS 2018 [https://arxiv.org/abs/1805.10915].

Note

This likelihood can be used for exact or approximate inference.

Parameters:
  • targets (torch.Tensor) – (… x N) Classification labels.

  • alpha_epsilon (float) – Tuning parameter for the scaling of the likelihood targets. We’d suggest 0.01 or setting it via cross-validation.

  • learn_additional_noise (bool, optional) – Set to true if you additionally want to learn added diagonal noise, similar to GaussianLikelihood.

  • batch_shape (torch.Size) – The batch shape of the learned noise parameter (default []) if learn_additional_noise=True.

Variables:

noise (torch.Tensor) – \(\sigma^2\) parameter (noise)

Note

DirichletClassificationLikelihood has an analytic marginal distribution.

Example

>>> train_x = torch.randn(55, 1)
>>> labels = (train_x > 0).long().squeeze(-1)
>>> likelihood = DirichletClassificationLikelihood(targets=labels, learn_additional_noise=True)
>>> pred_y = likelihood(gp_model(train_x))
>>>
>>> test_x = torch.randn(21, 1)
>>> test_labels = (test_x > 0).long().squeeze(-1)
>>> pred_y = likelihood(gp_model(test_x), targets=test_labels)
marginal(function_dist, *args, **kwargs)[source]

Parameters:
  • function_dist (MultivariateNormal) – Distribution for \(f(x)\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

MultivariateNormal

Returns:

Analytic marginal \(p(\mathbf y)\).

BernoulliLikelihood

class gpytorch.likelihoods.BernoulliLikelihood[source]

Implements the Bernoulli likelihood used for GP classification, using Probit regression (i.e., the latent function is warped to be in [0,1] using the standard Normal CDF \(\Phi(x)\)). Given the identity \(\Phi(-x) = 1-\Phi(x)\), we can write the likelihood compactly as:

\[\begin{equation*} p(Y=y|f)=\Phi((2y - 1)f) \end{equation*}\]

Note

BernoulliLikelihood has an analytic marginal distribution.

Note

The labels should take values in {0, 1}.

marginal(function_dist, *args, **kwargs)[source]

Parameters:
  • function_dist (MultivariateNormal) – Distribution for \(f(x)\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

torch.distributions.bernoulli.Bernoulli

Returns:

Analytic marginal \(p(\mathbf y)\).
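
A brief sketch of the analytic probit marginal:

import torch
import gpytorch

likelihood = gpytorch.likelihoods.BernoulliLikelihood()
f_dist = gpytorch.distributions.MultivariateNormal(torch.zeros(4), torch.eye(4))
y_dist = likelihood(f_dist)  # torch.distributions.Bernoulli
print(y_dist.probs)          # P(y = 1) for each input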

BetaLikelihood

class gpytorch.likelihoods.BetaLikelihood(batch_shape=torch.Size([]), scale_prior=None, scale_constraint=None)[source]

A Beta likelihood for regressing over percentages.

The Beta distribution is parameterized by concentration parameters \(\alpha > 0\) and \(\beta > 0\), which roughly correspond to the number of prior positive and negative observations. We instead parameterize it through a mixture parameter \(m \in [0, 1]\) and a scale parameter \(s > 0\):

\[\begin{equation*} \alpha = ms, \quad \beta = (1-m)s \end{equation*}\]

The mixture parameter is the output of the GP passed through a sigmoid (logistic) function \(\sigma(\cdot)\). The scale parameter is learned.

\[p(y \mid f) = \text{Beta} \left( \sigma(f) s , (1 - \sigma(f)) s\right)\]
Parameters:
  • batch_shape (torch.Size) – The batch shape of the learned scale parameter (default: []).

  • scale_prior (Prior, optional) – Prior for scale parameter \(s\).

  • scale_constraint (Interval, optional) – Constraint for scale parameter \(s\).

Variables:

scale (torch.Tensor) – \(s\) parameter (scale)
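
The \((m, s) \to (\alpha, \beta)\) mapping described above can be written out directly (a standalone math sketch, not the library’s internals):

import torch

def beta_params(f, s):
    m = torch.sigmoid(f)          # mixture parameter, in [0, 1]
    return m * s, (1 - m) * s     # alpha = m * s, beta = (1 - m) * s

alpha, beta = beta_params(torch.randn(3), torch.tensor(10.0))
y_dist = torch.distributions.Beta(alpha, beta)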

LaplaceLikelihood

class gpytorch.likelihoods.LaplaceLikelihood(batch_shape=torch.Size([]), noise_prior=None, noise_constraint=None)[source]

A Laplace likelihood/noise model for GP regression. It has one learnable parameter: the noise \(\sigma\).

Parameters:
  • batch_shape (torch.Size) – The batch shape of the learned noise parameter (default: []).

  • noise_prior (Prior, optional) – Prior for noise parameter \(\sigma\).

  • noise_constraint (Interval, optional) – Constraint for noise parameter \(\sigma\).

Variables:

noise (torch.Tensor) – \(\sigma\) parameter (noise)

StudentTLikelihood

class gpytorch.likelihoods.StudentTLikelihood(batch_shape=torch.Size([]), deg_free_prior=None, deg_free_constraint=None, noise_prior=None, noise_constraint=None)[source]

A Student T likelihood/noise model for GP regression. It has two learnable parameters: the degrees of freedom \(\nu\) and the noise \(\sigma^2\).

Parameters:
  • batch_shape (torch.Size) – The batch shape of the learned noise parameter (default: []).

  • noise_prior (Prior, optional) – Prior for noise parameter \(\sigma^2\).

  • noise_constraint (Interval, optional) – Constraint for noise parameter \(\sigma^2\).

  • deg_free_prior (Prior, optional) – Prior for deg_free parameter \(\nu\).

  • deg_free_constraint (Interval, optional) – Constraint for deg_free parameter \(\nu\).

Variables:
  • deg_free (torch.Tensor) – \(\nu\) parameter (degrees of freedom)

  • noise (torch.Tensor) – \(\sigma^2\) parameter (noise)

Multi-Dimensional Likelihoods

Likelihoods for GPs that are distributions of vector-valued functions. (I.e. for a specific \(\mathbf x\) we expect that \(f(\mathbf x) \in \mathbb{R}^t\), where \(t\) is the number of output dimensions.)

MultitaskGaussianLikelihood

class gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks, rank=0, batch_shape=torch.Size([]), task_prior=None, noise_prior=None, noise_constraint=None, has_global_noise=True, has_task_noise=True)[source]

A convenient extension of the GaussianLikelihood to the multitask setting that allows for a full cross-task covariance structure for the noise. The fitted covariance matrix has rank rank. If a strictly diagonal task noise covariance matrix is desired, set rank=0. (This option still allows for a different noise parameter for each task.)

Like the Gaussian likelihood, this object can be used with exact inference.

Note

At least one of has_global_noise or has_task_noise should be specified.

Note

MultitaskGaussianLikelihood has an analytic marginal distribution.

Parameters:
  • num_tasks (int) – Number of tasks.

  • noise_covar – A model for the noise covariance. This can be a simple homoskedastic noise model, or a GP that is to be fitted on the observed measurement errors.

  • rank (int) – The rank of the task noise covariance matrix to fit. If rank is set to 0, then a diagonal covariance matrix is fit.

  • task_prior (Prior, optional) – Prior to use over the task noise correlation matrix. Only used when \(\text{rank} > 0\).

  • batch_shape (torch.Size) – Number of batches.

  • has_global_noise (bool) – Whether to include a \(\sigma^2 \mathbf I_{nt}\) term in the noise model.

  • has_task_noise (bool) – Whether to include task-specific noise terms, which add \(\mathbf I_n \otimes \mathbf D_T\) into the noise model.

  • noise_prior (Prior, optional) – Prior for the global noise parameter \(\sigma^2\).

  • noise_constraint (Interval, optional) – Constraint for the global noise parameter \(\sigma^2\).

Variables:
  • task_noise_covar (torch.Tensor) – The inter-task noise covariance matrix

  • task_noises (torch.Tensor) – (Optional) task specific noise variances (added onto the task_noise_covar)

  • noise (torch.Tensor) – (Optional) global noise variance (added onto the task_noise_covar)

marginal(function_dist, *args, **kwargs)[source]

Parameters:
  • function_dist (MultitaskMultivariateNormal) – Distribution for \(f(x)\).

  • args – Additional args (passed to the forward function).

  • kwargs – Additional kwargs (passed to the forward function).

Return type:

MultitaskMultivariateNormal

Returns:

Analytic marginal \(p(\mathbf y)\).
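
A brief usage sketch (shapes are illustrative: 10 data points, 3 tasks, so the joint covariance is 30 x 30):

import torch
import gpytorch

likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=3, rank=1)
f_dist = gpytorch.distributions.MultitaskMultivariateNormal(
    torch.zeros(10, 3),  # mean: n x t
    torch.eye(30),       # covariance: nt x nt
)
y_dist = likelihood(f_dist)  # MultitaskMultivariateNormal marginal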

SoftmaxLikelihood

class gpytorch.likelihoods.SoftmaxLikelihood(num_features=None, num_classes=None, mixing_weights=True, mixing_weights_prior=None)[source]

Implements the Softmax (multiclass) likelihood used for GP classification.

\[p(\mathbf y \mid \mathbf f) = \text{Softmax} \left( \mathbf W \mathbf f \right)\]

\(\mathbf W\) is a set of linear mixing weights applied to the latent functions \(\mathbf f\).

Parameters:
  • num_features (int, optional) – Dimensionality of latent function \(\mathbf f\).

  • num_classes (int, optional) – Number of classes.

  • mixing_weights (bool) – (Default: True) Whether to learn a linear mixing weight \(\mathbf W\) applied to the latent function \(\mathbf f\). If False, then \(\mathbf W = \mathbf I\).

  • mixing_weights_prior (Prior, optional) – Prior to use over the mixing weights \(\mathbf W\).

Variables:

mixing_weights (torch.Tensor) – (Optional) mixing weights.
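
A brief sketch of the conditional for function samples (shapes are illustrative and assume samples are laid out as (…, num_data, num_features)):

import torch
import gpytorch

likelihood = gpytorch.likelihoods.SoftmaxLikelihood(num_features=10, num_classes=4)
f_samples = torch.randn(20, 10)  # 20 data points x 10 latent features
y_dist = likelihood(f_samples)   # Categorical distribution over the 4 classes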