Source code for gpytorch.kernels.index_kernel

#!/usr/bin/env python3

from typing import Optional

import torch
from linear_operator.operators import (
    DiagLinearOperator,
    InterpolatedLinearOperator,
    PsdSumLinearOperator,
    RootLinearOperator,
)

from ..constraints import Interval, Positive
from ..priors import Prior
from .kernel import Kernel


class IndexKernel(Kernel):
    r"""
    A kernel for discrete indices. The kernel is defined by a lookup table.

    .. math::
        \begin{equation}
            k(i, j) = \left(BB^\top + \text{diag}(\mathbf v) \right)_{i, j}
        \end{equation}

    where :math:`B` is a low-rank matrix, and :math:`\mathbf v` is a non-negative vector.
    These parameters are learned.

    Args:
        num_tasks (int):
            Total number of indices.
        batch_shape (torch.Size, optional):
            Set this if the kernel is operating on batches of data (and you want
            different parameters for each batch).
        rank (int):
            Rank of the :math:`B` matrix. Controls the degree of correlation between
            the outputs. With a rank of 1 the outputs are identical except for a
            scaling factor.
        prior (:obj:`gpytorch.priors.Prior`):
            Prior for the :math:`B` matrix.
        var_constraint (Constraint, optional):
            Constraint for the added diagonal component. Default: `Positive`.

    Attributes:
        covar_factor: The :math:`B` matrix.
        raw_var: The unconstrained (raw) parameterization of the :math:`\mathbf v`
            vector; :math:`\mathbf v` is recovered through the `var_constraint`
            transform.
    """

    def __init__(
        self,
        num_tasks: int,
        rank: Optional[int] = 1,
        prior: Optional[Prior] = None,
        var_constraint: Optional[Interval] = None,
        **kwargs,
    ):
        if rank > num_tasks:
            raise RuntimeError("Cannot create a task covariance matrix larger than the number of tasks")
        super().__init__(**kwargs)

        if var_constraint is None:
            var_constraint = Positive()

        # B has shape (*batch_shape, num_tasks, rank); v has shape (*batch_shape, num_tasks).
        self.register_parameter(
            name="covar_factor", parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks, rank))
        )
        self.register_parameter(
            name="raw_var", parameter=torch.nn.Parameter(torch.randn(*self.batch_shape, num_tasks))
        )
        if prior is not None:
            if not isinstance(prior, Prior):
                raise TypeError("Expected gpytorch.priors.Prior but got " + type(prior).__name__)
            self.register_prior("IndexKernelPrior", prior, lambda m: m._eval_covar_matrix())

        self.register_constraint("raw_var", var_constraint)

    @property
    def var(self):
        return self.raw_var_constraint.transform(self.raw_var)

    @var.setter
    def var(self, value):
        self._set_var(value)

    def _set_var(self, value):
        self.initialize(raw_var=self.raw_var_constraint.inverse_transform(value))

    def _eval_covar_matrix(self):
        # Dense evaluation of B B^T + diag(v).
        cf = self.covar_factor
        return cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)

    @property
    def covar_matrix(self):
        # Lazy (linear operator) representation of B B^T + diag(v).
        var = self.var
        res = PsdSumLinearOperator(RootLinearOperator(self.covar_factor), DiagLinearOperator(var))
        return res

    def forward(self, i1, i2, **params):
        i1, i2 = i1.long(), i2.long()
        covar_matrix = self._eval_covar_matrix()
        batch_shape = torch.broadcast_shapes(i1.shape[:-2], i2.shape[:-2], self.batch_shape)

        # Gather the (i1, i2) entries of the num_tasks x num_tasks lookup table.
        res = InterpolatedLinearOperator(
            base_linear_op=covar_matrix,
            left_interp_indices=i1.expand(batch_shape + i1.shape[-2:]),
            right_interp_indices=i2.expand(batch_shape + i2.shape[-2:]),
        )
        return res
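
For context, a minimal usage sketch follows. The task counts, ranks, and index tensors below are illustrative choices, not part of the module; it assumes a standard GPyTorch install where kernel outputs are linear operators exposing `to_dense()`.

import torch
from gpytorch.kernels import IndexKernel

# Toy setup (illustrative): 4 discrete tasks, rank-2 covariance factor B.
kernel = IndexKernel(num_tasks=4, rank=2)

# Task indices are integer tensors with a trailing singleton dimension,
# e.g. shape (n, 1): each row selects one task.
i1 = torch.tensor([[0], [1], [2]])
i2 = torch.tensor([[0], [3]])

# Evaluation is lazy; densify to get the 3 x 2 block of entries
# (B B^T + diag(v))[i1, i2].
block = kernel(i1, i2).to_dense()
print(block.shape)  # torch.Size([3, 2])

# The full num_tasks x num_tasks lookup table, also lazily represented:
print(kernel.covar_matrix.to_dense().shape)  # torch.Size([4, 4])

Because `forward` only gathers rows and columns of the lookup table via an `InterpolatedLinearOperator`, the full task covariance is never materialized per data point; the diagonal vector can also be set directly through the `var` setter, which routes the value through the registered constraint's inverse transform.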