# Source code for pytorch_ood.model.centers

import logging

import torch
from torch import nn

from .. import utils

# module-level logger; RunningCenters.reset() reports resets through it
log = logging.getLogger(__name__)


class ClassCenters(nn.Module):
    """
    Several methods for OOD Detection propose to model a center :math:`\\mu_y` for each class.
    These centers are either static, or learned via gradient descent.

    The centers are also known as class proxy, class prototype or class anchor.
    """

    def __init__(self, n_classes: int, n_features: int, fixed: bool = False):
        """
        :param n_classes: number of class centers
        :param n_features: dimensionality of the space in which the centers live
        :param fixed: False if the centers should be learnable parameters, True if they should be fixed at their
            initial position
        """
        super(ClassCenters, self).__init__()
        # centers are initialized by sampling from a standard normal distribution
        self._params = nn.Parameter(torch.randn(size=(n_classes, n_features)))
        if fixed:
            # fixed centers stay a registered parameter (saved in the state dict, moved with the module),
            # but gradients are not computed for them, so optimizers leave them at their initial position
            self._params.requires_grad = False

    @property
    def num_classes(self) -> int:
        """
        :return: number of class centers
        """
        return self.params.shape[0]

    @property
    def n_features(self) -> int:
        """
        :return: dimensionality of each center
        """
        return self.params.shape[1]

    @property
    def params(self) -> nn.Parameter:
        """
        Class centers :math:`\\mu`, one row per class.
        """
        return self._params

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: samples with features along dimension 1; ``x.shape[1]`` must equal :attr:`n_features`
        :returns: pairwise (squared) distances of samples to each center, as computed by
            ``utils.pairwise_distances(x, params)``
        """
        assert x.shape[1] == self.n_features
        return utils.pairwise_distances(x, self.params)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Make class membership predictions based on the softmin of the distances to each center.

        :param x: embeddings of samples; ``x.shape[1]`` must equal :attr:`n_features`
        :returns: softmin of the distances along dimension 1, i.e. per sample non-negative scores that sum to
            one and are largest for the nearest center
        """
        # reuse forward() so predictions get the same feature-dimension check as distances
        distances = self.forward(x)
        return nn.functional.softmin(distances, dim=1)
class RunningCenters(nn.Module):
    """
    Estimates class centers from batches of data using a running mean estimator.

    The center of each class is the cumulative average of the per-batch class means, taken over the
    batches in which that class occurred. Classes that are absent from a batch keep their current estimate.
    """

    def __init__(self, n_classes: int, n_embedding: int):
        """
        :param n_classes: number of centers
        :param n_embedding: dimensionality of embedding space
        """
        super(RunningCenters, self).__init__()
        self.num_classes = n_classes
        self.n_embedding = n_embedding
        # buffers are saved in the state dict and follow .to()/.cuda(), but are not trainable;
        # they only change through update() and reset()
        running_centers = torch.zeros(size=(self.num_classes, self.n_embedding), dtype=torch.float32)
        # total number of update() calls
        num_batches_tracked = torch.zeros(size=(1,), dtype=torch.float32)
        # number of update() calls in which each class occurred. Every class needs its own count:
        # weighting by the global counter would average classes missing from earlier batches with zeros.
        class_batches_tracked = torch.zeros(size=(self.num_classes,), dtype=torch.float32)
        self.register_buffer("running_centers", running_centers)
        self.register_buffer("num_batches_tracked", num_batches_tracked)
        self.register_buffer("class_batches_tracked", class_batches_tracked)
        self.reset()

    @property
    def centers(self) -> torch.Tensor:
        """
        :return: current class center estimates
        """
        return self.running_centers

    def reset(self) -> None:
        """
        Resets the running stats of online class center estimates.
        """
        log.info("Reset running stats")
        nn.init.zeros_(self.running_centers)
        nn.init.zeros_(self.num_batches_tracked)
        nn.init.zeros_(self.class_batches_tracked)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # State dicts written before per-class counts existed only hold the global batch counter.
        # Seeding every class with it reproduces the previous update rule and keeps strict loading working.
        class_key = prefix + "class_batches_tracked"
        global_key = prefix + "num_batches_tracked"
        if class_key not in state_dict and global_key in state_dict:
            state_dict[class_key] = state_dict[global_key].reshape(-1)[:1].expand(self.num_classes).clone()
        super(RunningCenters, self)._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def calculate_centers(self, embeddings: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Calculates the mean embedding of every class that occurs in a batch.

        :param embeddings: embeddings, one row per sample
        :param target: class labels, used as row indices into the result
        :return: tensor of shape ``(num_classes, n_embedding)``; rows of classes that do not occur in
            ``target`` are NaN
        """
        mu = torch.full(
            size=(self.num_classes, self.n_embedding),
            fill_value=float("NaN"),
            device=embeddings.device,
        )
        for clazz in target.unique(sorted=False):
            # mean over all instances of this class
            mu[clazz] = embeddings[target == clazz].mean(dim=0)
        return mu

    def update(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Update running centers with the class means of a batch.

        :param x: inputs (embeddings), one row per sample
        :param target: class labels
        :return: per class mean of inputs, as returned by :meth:`calculate_centers`. It keeps the
            autograd graph of ``x``, while the running estimates do not.
        """
        batch_classes = torch.unique(target, sorted=False)
        # empirical centers of this batch; returned with gradients so callers can use them in a loss
        mu = self.calculate_centers(x, target)

        # Running statistics must not become part of the autograd graph of x. Otherwise the buffers would
        # turn into graph nodes that chain every batch to the next one.
        with torch.no_grad():
            counts = self.class_batches_tracked[batch_classes].unsqueeze(1)
            # per-class cumulative moving average: (mu + center * n) / (n + 1)
            cma = mu[batch_classes] + self.running_centers[batch_classes] * counts
            self.running_centers[batch_classes] = cma / (counts + 1)
            self.class_batches_tracked[batch_classes] += 1
            self.num_batches_tracked += 1

        return mu

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Calculates distances to centers

        :param x: samples
        :return: distance matrix, computed as ``utils.pairwise_distances(centers, x)``
        """
        # NOTE(review): the argument order (centers, x) is the reverse of ClassCenters.forward, so the
        # result is presumably oriented centers x samples — confirm against utils.pairwise_distances
        return utils.pairwise_distances(self.centers, x)