Source code for pytorch_ood.loss.cac

"""
CACLoss
----------------------------------------------

..  automodule:: pytorch_ood.loss.cac
    :members: CACLoss

"""
import torch as torch
import torch.nn as nn

#
from torch.nn import functional as F

from ..model.centers import ClassCenters
from ..utils import is_known


class CACLoss(nn.Module):
    """
    Class Anchor Clustering Loss from the paper
    *Class Anchor Clustering: a Distance-based Loss for Training Open Set Classifiers*

    :see Paper: `WACV 2022 <https://arxiv.org/abs/2004.02434>`_
    :see Implementation: `GitHub <https://github.com/dimitymiller/cac-openset/>`_
    """

    def __init__(self, n_classes: int, magnitude: float = 1.0, alpha: float = 1.0):
        """
        Centers are initialized as unit vectors, scaled by the magnitude.

        :param n_classes: number of classes :math:`C`
        :param magnitude: magnitude of class anchors
        :param alpha: :math:`\\alpha` weight for anchor term
        """
        super().__init__()
        self.n_classes = n_classes
        self.magnitude = magnitude
        self.alpha = alpha
        # anchor points are fixed, so they do not require gradients
        self._centers = ClassCenters(n_classes, n_classes, fixed=True)
        self._init_centers()

    @property
    def centers(self) -> "ClassCenters":
        """
        The class centers :math:`\\mu_y`.
        """
        return self._centers

    def _init_centers(self) -> None:
        """Init anchors as rows of the identity matrix, scaled by ``magnitude``."""
        nn.init.eye_(self.centers.params)
        self.centers.params.data *= self.magnitude

    def forward(self, distances: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Calculates the CAC loss, based on the given distance matrix and target labels.
        OOD inputs will be ignored.

        The loss is :math:`\\alpha L_A + L_T`, where the anchor term is :math:`L_A = d_y` and
        the tuplet term is :math:`L_T = \\log(1 + \\sum_{j \\neq y} e^{d_y - d_j})`.
        Both terms are averaged over the known samples of the batch.

        :param distances: matrix of distances of each point to each center with shape :math:`B \\times C`.
        :param target: labels for samples
        :return: scalar loss; zero if the batch contains no known samples
        :raises ValueError: if ``distances`` does not have shape :math:`B \\times C`
        """
        if distances.dim() != 2 or distances.shape[1] != self.n_classes:
            raise ValueError(
                f"Expected distances of shape (B, {self.n_classes}), "
                f"got {tuple(distances.shape)}"
            )

        known = is_known(target)
        d_known = distances[known]

        if d_known.shape[0] == 0:
            # The sum over the empty slice is 0, but unlike a fresh constant it stays
            # connected to ``distances`` in the autograd graph, so backward() still works.
            return d_known.sum()

        target_known = target[known]
        # distance of every known sample to the center of its own class, shape (B,)
        d_true = torch.gather(d_known, dim=1, index=target_known.view(-1, 1)).view(-1)
        # distances to all other centers, shape (B, C - 1)
        d_other = self._non_target_distances(d_known, target_known)

        anchor_loss = d_true.mean()
        tuplet_loss = self._tuplet_loss(d_true, d_other).mean()
        return self.alpha * anchor_loss + tuplet_loss

    @staticmethod
    def _non_target_distances(distances: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Select, for every row, the distances to all centers except the target's.

        :param distances: distance matrix with shape :math:`B \\times C`
        :param target: class index per row, shape :math:`B`, values in :math:`[0, C)`
        :return: matrix with shape :math:`B \\times (C - 1)`; columns keep their original order
        """
        n_samples, n_classes = distances.shape
        keep = torch.ones_like(distances, dtype=torch.bool)
        keep[torch.arange(n_samples, device=distances.device), target] = False
        # Boolean indexing yields elements in row-major order and every row has exactly
        # one masked entry, so the flat result reshapes cleanly to (B, C - 1).
        return distances[keep].view(n_samples, n_classes - 1)

    @staticmethod
    def _tuplet_loss(d_true: torch.Tensor, d_other: torch.Tensor) -> torch.Tensor:
        """
        Per-sample tuplet term :math:`\\log(1 + \\sum_{j \\neq y} e^{d_y - d_j})`.

        Computed as a log-sum-exp over ``[0, d_y - d_j, ...]``, which is exact and cannot
        overflow, so no clamping of the exponent is needed.

        :param d_true: distance to the target center per sample, shape :math:`B`
        :param d_other: distances to the remaining centers, shape :math:`B \\times (C - 1)`
        :return: tuplet loss per sample, shape :math:`B`
        """
        logits = d_true.unsqueeze(1) - d_other
        # the leading zero column contributes the "1 +" inside the logarithm (e^0 = 1)
        zeros = logits.new_zeros(logits.shape[0], 1)
        return torch.logsumexp(torch.cat([zeros, logits], dim=1), dim=1)

    def distance(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: input points
        :return: matrix with squared distances from each point to each center with shape :math:`B \\times C`.
        """
        return self.centers(x)

    @staticmethod
    def score(distance: torch.Tensor) -> torch.Tensor:
        """
        Rejection score proposed in the paper.

        Computes :math:`\\gamma = d \\circ (1 - \\text{softmin}(d))` per class and returns its
        minimum over classes. The class attaining the minimum is the predicted class.
        An input is rejected as unknown when this minimum is large, i.e. when it is not
        close to exactly one anchor.

        :param distance: distance of instances to class centers, shape :math:`B \\times C`
        :return: outlier scores, shape :math:`B`
        """
        scores = distance * (1 - F.softmin(distance, dim=1))
        return scores.min(dim=1).values