Source code for gpytorch.lazy.lazy_evaluated_kernel_tensor

#!/usr/bin/env python3

import torch

from .. import settings, beta_features
from ..utils import broadcasting
from ..utils.getitem import _noop_index
from ..utils.memoize import cached
from .lazy_tensor import LazyTensor
from .non_lazy_tensor import lazify


[docs]class LazyEvaluatedKernelTensor(LazyTensor): _check_size = False def _check_args(self, x1, x2, kernel, last_dim_is_batch=False, **params): if not torch.is_tensor(x1): return "x1 must be a tensor. Got {}".format(x1.__class__.__name__) if not torch.is_tensor(x2): return "x1 must be a tensor. Got {}".format(x1.__class__.__name__) def __init__(self, x1, x2, kernel, last_dim_is_batch=False, **params): super(LazyEvaluatedKernelTensor, self).__init__( x1, x2, kernel=kernel, last_dim_is_batch=last_dim_is_batch, **params ) self.kernel = kernel self.x1 = x1 self.x2 = x2 self.last_dim_is_batch = last_dim_is_batch self.params = params @property def dtype(self): return self.kernel.dtype @property def device(self): return self.x1.device def _expand_batch(self, batch_shape): return self.evaluate_kernel()._expand_batch(batch_shape) def _getitem(self, row_index, col_index, *batch_indices): x1 = self.x1 x2 = self.x2 num_outs_per_in = self.kernel.num_outputs_per_input(x1, x2) # The row index and col index should exactly correspond to which entries of x1 and x2 we need # So we'll basically call x1[*batch_indices, row_index, :], x2[*batch_indices, col_index, :] # However - if we have multiple outputs per input, then the indices won't directly # correspond to the entries of row/col. We'll have to do a little pre-processing if num_outs_per_in != 1: if not isinstance(x1, slice) or not isinstance(x2, slice): # It's too complicated to deal with tensor indices in this case - we'll use the super method return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices) # Now we know that x1 and x2 are slices # Let's make sure that the slice dimensions perfectly correspond with the number of # outputs per input that we have row_start, row_end, row_step = row_index.start, row_index.stop, row_index.step col_start, col_end, col_step = col_index.start, col_index.stop, col_index.step if row_step is not None or col_step is not None: return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices) if ( (row_start % num_outs_per_in) or (col_start % num_outs_per_in) or (row_end % num_outs_per_in) or (col_end % num_outs_per_in) ): return self.evaluate_kernel()._getitem(row_index, col_index, *batch_indices) # Otherwise - let's divide the slices by the number of outputs per input row_index = slice(row_start // num_outs_per_in, row_end // num_outs_per_in, None) col_index = slice(col_start // num_outs_per_in, col_end // num_outs_per_in, None) # Define the index we're using for the last index # If the last index corresponds to a batch, then we'll use the appropriate batch_index # Otherwise, we'll use the _noop_index if self.last_dim_is_batch: *batch_indices, dim_index = batch_indices else: dim_index = _noop_index # Get the indices of x1 and x2 that matter for the kernel # Call x1[*batch_indices, row_index, :] try: x1 = x1[(*batch_indices, row_index, dim_index)] # We're going to handle multi-batch indexing with a try-catch loop # This way - in the default case, we can avoid doing expansions of x1 which can be timely except IndexError: if isinstance(batch_indices, slice): x1 = x1.expand(1, *self.x1.shape[-2:])[(*batch_indices, row_index, dim_index)] elif isinstance(batch_indices, tuple): if any(not isinstance(bi, slice) for bi in batch_indices): raise RuntimeError( f"Attempting to tensor index a non-batch matrix's batch dimensions. " "Got batch index {batch_indices} but my shape was {self.shape}" ) x1 = x1.expand(*([1] * len(batch_indices)), *self.x1.shape[-2:]) x1 = x1[(*batch_indices, row_index, dim_index)] # Call x2[*batch_indices, row_index, :] try: x2 = x2[(*batch_indices, col_index, dim_index)] # We're going to handle multi-batch indexing with a try-catch loop # This way - in the default case, we can avoid doing expansions of x1 which can be timely except IndexError: if isinstance(batch_indices, slice): x2 = x2.expand(1, *self.x2.shape[-2:])[(*batch_indices, row_index, dim_index)] elif isinstance(batch_indices, tuple): if any([not isinstance(bi, slice) for bi in batch_indices]): raise RuntimeError( f"Attempting to tensor index a non-batch matrix's batch dimensions. " "Got batch index {batch_indices} but my shape was {self.shape}" ) x2 = x2.expand(*([1] * len(batch_indices)), *self.x2.shape[-2:]) x2 = x2[(*batch_indices, row_index, dim_index)] # Now construct a kernel with those indices return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params) def _matmul(self, rhs): # This _matmul is defined computes the kernel in chunks # It is only used when we are using kernel checkpointing # It won't be called if checkpointing is off x1 = self.x1 x2 = self.x2 split_size = beta_features.checkpoint_kernel.value() if not split_size: raise RuntimeError( "Should not have ended up in LazyEvaluatedKernelTensor._matmul without kernel checkpointing. " "This is probably a bug in GPyTorch." ) with torch.no_grad(), settings.lazily_evaluate_kernels(False): sub_x1s = torch.split(x1, split_size, dim=-2) res = [] for sub_x1 in sub_x1s: sub_kernel_matrix = lazify( self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params) ) res.append(sub_kernel_matrix._matmul(rhs)) res = torch.cat(res, dim=-2) return res def _quad_form_derivative(self, left_vecs, right_vecs): # This _quad_form_derivative computes the kernel in chunks # It is only used when we are using kernel checkpointing # It won't be called if checkpointing is off split_size = beta_features.checkpoint_kernel.value() if not split_size: raise RuntimeError( "Should not have ended up in LazyEvaluatedKernelTensor._quad_form_derivative without kernel " "checkpointing. This is probably a bug in GPyTorch." ) x1 = self.x1.detach() x2 = self.x2.detach() # Break objects into chunks sub_x1s = torch.split(x1, split_size, dim=-2) sub_left_vecss = torch.split(left_vecs, split_size, dim=-2) # Compute the gradient in chunks for sub_x1, sub_left_vecs in zip(sub_x1s, sub_left_vecss): with torch.enable_grad(), settings.lazily_evaluate_kernels(False): sub_kernel_matrix = lazify( self.kernel(sub_x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params) ) sub_grad_outputs = tuple(sub_kernel_matrix._quad_form_derivative(sub_left_vecs, right_vecs)) sub_kernel_outputs = tuple(sub_kernel_matrix.representation()) torch.autograd.backward(sub_kernel_outputs, sub_grad_outputs) return x1.grad, x2.grad @cached(name="size") def _size(self): if settings.debug.on(): if hasattr(self.kernel, "size"): raise RuntimeError("Kernels must define `num_outputs_per_input` and should not define `size`") x1 = self.x1 x2 = self.x2 num_outputs_per_input = self.kernel.num_outputs_per_input(x1, x2) num_rows = x1.size(-2) * num_outputs_per_input num_cols = x2.size(-2) * num_outputs_per_input # Default case - when we're not using broadcasting # We write this case special for efficiency if x1.shape[:-2] == x2.shape[:-2] and x1.shape[:-2] == self.kernel.batch_shape: expected_size = self.kernel.batch_shape + torch.Size((num_rows, num_cols)) # When we're using broadcasting else: expected_size = broadcasting._matmul_broadcast_shape( torch.Size([*x1.shape[:-2], num_rows, x1.size(-1)]), torch.Size([*x2.shape[:-2], x2.size(-1), num_cols]), error_msg="x1 and x2 were not broadcastable to a proper kernel shape. " "Got x1.shape = {} and x2.shape = {}".format(str(x1.shape), str(x2.shape)) ) expected_size = broadcasting._mul_broadcast_shape( expected_size[:-2], self.kernel.batch_shape, error_msg=( f"x1 and x2 were not broadcastable with kernel of batch_shape {self.kernel.batch_shape}. " f"Got x1.shape = {x1.shape} and x2.shape = {x2.shape}" ) ) + expected_size[-2:] # Handle when the last dim is batch if self.last_dim_is_batch: expected_size = expected_size[:-2] + x1.shape[-1:] + expected_size[-2:] return expected_size def _transpose_nonbatch(self): return self.__class__( self.x2, self.x1, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params ) def add_jitter(self, jitter_val=1e-3): return self.evaluate_kernel().add_jitter(jitter_val) def _unsqueeze_batch(self, dim): return self.evaluate_kernel()._unsqueeze_batch(dim)
[docs] @cached(name="kernel_diag") def diag(self): """ Getting the diagonal of a kernel can be handled more efficiently by transposing the batch and data dimension before calling the kernel. Implementing it this way allows us to compute predictions more efficiently in cases where only the variances are required. """ from ..kernels import Kernel x1 = self.x1 x2 = self.x2 res = super(Kernel, self.kernel).__call__( x1, x2, diag=True, last_dim_is_batch=self.last_dim_is_batch, **self.params ) # Now we'll make sure that the shape we're getting from diag makes sense if settings.debug.on(): expected_shape = self.shape[:-1] if res.shape != expected_shape: raise RuntimeError( "The kernel {} is not equipped to handle and diag. Expected size {}. " "Got size {}".format(self.__class__.__name__, expected_shape, res.shape) ) if isinstance(res, LazyTensor): res = res.evaluate() return res.view(self.shape[:-1]).contiguous()
[docs] @cached(name="kernel_eval") def evaluate_kernel(self): """ NB: This is a meta LazyTensor, in the sense that evaluate can return a LazyTensor if the kernel being evaluated does so. """ x1 = self.x1 x2 = self.x2 with settings.lazily_evaluate_kernels(False): temp_active_dims = self.kernel.active_dims self.kernel.active_dims = None res = self.kernel( x1, x2, diag=False, last_dim_is_batch=self.last_dim_is_batch, **self.params ) self.kernel.active_dims = temp_active_dims # Check the size of the output if settings.debug.on(): if res.shape != self.shape: raise RuntimeError( f"The expected shape of the kernel was {self.shape}, but got {res.shape}. " "This is likely a bug in GPyTorch." ) return lazify(res)
@cached def evaluate(self): return self.evaluate_kernel().evaluate() def ndimension(self): # TODO: fix this with the remove_batch_dim PR return max(self.x2.dim(), self.x2.dim()) def repeat(self, *repeats): if len(repeats) == 1 and hasattr(repeats[0], "__iter__"): repeats = repeats[0] *batch_repeat, row_repeat, col_repeat = repeats x1 = self.x1.repeat(*batch_repeat, row_repeat, 1) x2 = self.x2.repeat(*batch_repeat, col_repeat, 1) return self.__class__(x1, x2, kernel=self.kernel, last_dim_is_batch=self.last_dim_is_batch, **self.params) def representation(self): # If we're checkpointing the kernel, we'll use chunked _matmuls defined in LazyEvaluatedKernelTensor if beta_features.checkpoint_kernel.value(): return super().representation() # Otherwise, we'll evaluate the kernel (or at least its LazyTensor representation) and use its # representation else: return self.evaluate_kernel().representation() def representation_tree(self): # If we're checkpointing the kernel, we'll use chunked _matmuls defined in LazyEvaluatedKernelTensor if beta_features.checkpoint_kernel.value(): return super().representation_tree() # Otherwise, we'll evaluate the kernel (or at least its LazyTensor representation) and use its # representation else: return self.evaluate_kernel().representation_tree() def __getitem__(self, index): """ Supports subindexing of the matrix this LazyTensor represents. This may return either another :obj:`gpytorch.lazy.LazyTensor` or a :obj:`torch.tensor` depending on the exact implementation. """ # Process the index index = index if isinstance(index, tuple) else (index,) # Special case for the most common case: [..., slice, slice] if len(index) == 3 and index[0] is Ellipsis and isinstance(index[1], slice) and isinstance(index[2], slice): _, row_index, col_index = index batch_indices = [slice(None, None, None)] * (self.dim() - 2) return self._getitem(row_index, col_index, *batch_indices) else: return super().__getitem__(index)