Source code for gpytorch.kernels.keops.matern_kernel

#!/usr/bin/env python3
import math

from linear_operator.operators import KernelLinearOperator

from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel


def _covar_func(x1, x2, nu=2.5, **params):
    x1_, x2_ = _lazify_and_expand_inputs(x1, x2)

    sq_distance = ((x1_ - x2_) ** 2).sum(-1)
    distance = (sq_distance + 1e-20).sqrt()
    # ^^ Need to add an epsilon to prevent NaNs in the sqrt backward pass when the
    # squared distance is numerically zero (or slightly negative).
    # Using .clamp(1e-20, math.inf) doesn't work with KeOps; it also produces NaNs.
    exp_component = (-math.sqrt(nu * 2) * distance).exp()

    if nu == 0.5:
        constant_component = 1
    elif nu == 1.5:
        constant_component = (math.sqrt(3) * distance) + 1
    elif nu == 2.5:
        constant_component = (math.sqrt(5) * distance) + (1 + 5.0 / 3.0 * sq_distance)

    return constant_component * exp_component

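# For reference, _covar_func evaluates the standard half-integer Matern covariances
# on the (already lengthscale-scaled) inputs, with d = ||x1 - x2||:
#   nu = 0.5:  k(d) = exp(-d)
#   nu = 1.5:  k(d) = (1 + sqrt(3) * d) * exp(-sqrt(3) * d)
#   nu = 2.5:  k(d) = (1 + sqrt(5) * d + (5/3) * d^2) * exp(-sqrt(5) * d)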

class MaternKernel(KeOpsKernel):
    """
    Implements the Matern kernel using KeOps as a driver for kernel matrix multiplies.

    This class can be used as a drop-in replacement for :class:`gpytorch.kernels.MaternKernel`
    in most cases, and supports the same arguments.

    :param nu: (Default: 2.5) The smoothness parameter.
    :type nu: float (0.5, 1.5, or 2.5)
    :param ard_num_dims: (Default: `None`) Set this if you want a separate lengthscale for each
        input dimension. It should be `d` if x1 is a `... x n x d` matrix.
    :type ard_num_dims: int, optional
    :param batch_shape: (Default: `None`) Set this if you want a separate lengthscale for each
        batch of input data. It should be `torch.Size([b1, b2])` for a `b1 x b2 x n x m` kernel output.
    :type batch_shape: torch.Size, optional
    :param active_dims: (Default: `None`) Set this if you want to compute the covariance of only
        a few input dimensions. The ints correspond to the indices of the dimensions.
    :type active_dims: Tuple(int)
    :param lengthscale_prior: (Default: `None`) Set this if you want to apply a prior to the
        lengthscale parameter.
    :type lengthscale_prior: ~gpytorch.priors.Prior, optional
    :param lengthscale_constraint: (Default: `Positive`) Set this if you want to apply a constraint
        to the lengthscale parameter.
    :type lengthscale_constraint: ~gpytorch.constraints.Interval, optional
    :param eps: (Default: 1e-6) The minimum value that the lengthscale can take
        (prevents divide by zero errors).
    :type eps: float, optional
    """

    has_lengthscale = True

    def __init__(self, nu=2.5, **kwargs):
        if nu not in {0.5, 1.5, 2.5}:
            raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5")
        super().__init__(**kwargs)
        self.nu = nu

    def forward(self, x1, x2, **kwargs):
        mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]

        x1_ = (x1 - mean) / self.lengthscale
        x2_ = (x2 - mean) / self.lengthscale

        # Return a KernelLinearOperator instance only when computing the whole covariance matrix
        return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)
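

# Illustrative usage sketch (an assumption, not part of the gpytorch module itself):
# since this kernel mirrors gpytorch.kernels.MaternKernel's interface, it can be
# dropped into a standard ExactGP model. The toy data, the `_DemoGP` model class,
# and the training setup below are made up for demonstration only.
if __name__ == "__main__":
    import torch
    import gpytorch

    train_x = torch.randn(100, 3)
    train_y = torch.sin(train_x.sum(-1))

    class _DemoGP(gpytorch.models.ExactGP):
        def __init__(self, train_x, train_y, likelihood):
            super().__init__(train_x, train_y, likelihood)
            self.mean_module = gpytorch.means.ConstantMean()
            # KeOps-backed Matern kernel, wrapped in a ScaleKernel for an outputscale
            self.covar_module = gpytorch.kernels.ScaleKernel(
                gpytorch.kernels.keops.MaternKernel(nu=2.5, ard_num_dims=3)
            )

        def forward(self, x):
            return gpytorch.distributions.MultivariateNormal(
                self.mean_module(x), self.covar_module(x)
            )

    likelihood = gpytorch.likelihoods.GaussianLikelihood()
    model = _DemoGP(train_x, train_y, likelihood)

    # One marginal log likelihood evaluation; the covariance matmuls run through KeOps.
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)
    loss = -mll(model(train_x), train_y)
    print(loss.item())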