Source code for pytorch_ood.detector.nnguide

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: segmentation badge
.. image:: https://img.shields.io/badge/AI_Coded-yes-blue?style=flat-square
   :alt: slop-badge

..  autoclass:: pytorch_ood.detector.NNGuide
    :members:

"""

import logging
from typing import Callable, Optional, TypeVar

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from pytorch_ood.api import Detector, ModelNotSetException, RequiresFittingException
from pytorch_ood.utils import extract_features, is_known

log = logging.getLogger(__name__)
Self = TypeVar("Self")


class NNGuide(Detector):
    """
    Implements *Nearest Neighbor Guidance for Out-of-Distribution Detection*.

    Guides classifier-based scores using k-NN similarity to an energy-weighted feature bank.
    The feature bank is constructed by scaling in-distribution training features with their
    corresponding energy scores.

    At inference, the outlier score is the negated product of the k-NN guidance
    (mean inner product with the energy-scaled feature bank) and the sample's own energy:

    .. math::
        s(x) = - \\underbrace{\\frac{1}{k} \\sum_{z \\in \\mathcal{N}_k(x)}
        \\langle f(x),\\, E(z) \\cdot f(z) \\rangle}_{\\text{guidance}} \\cdot E(x)

    where :math:`E(x) = \\log \\sum_i \\exp(l_i(x))` is the energy score,
    :math:`f(x)` are the penultimate-layer features, and :math:`\\mathcal{N}_k(x)` are the
    :math:`k` nearest neighbors in the energy-scaled feature bank measured by inner product.

    The model passed to the constructor should extract penultimate-layer features.
    The classification head weights ``w`` and biases ``b`` are used internally to compute
    logits from features, similar to :class:`ViM`.

    If the fitted feature bank holds fewer than ``k`` samples, all of its samples are used
    as neighbors.

    :see Paper: `arXiv <https://arxiv.org/abs/2309.14888>`__
    """

    def __init__(
        self,
        model: Callable[[Tensor], Tensor],
        w: Tensor,
        b: Tensor,
        k: int = 10,
    ):
        """
        :param model: neural network that extracts penultimate-layer features
        :param w: weight matrix of the classification head, shape ``(num_classes, feature_dim)``
        :param b: bias vector of the classification head, shape ``(num_classes,)``
        :param k: number of nearest neighbors for guidance (default: 10)
        :raise ValueError: if ``k`` is smaller than 1
        """
        super().__init__()

        # k == 0 would average over an empty neighbor set and silently yield NaN scores,
        # negative k makes topk fail deep inside predict; reject both up front.
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")

        self.model = model
        # The head is kept on the CPU in float32; all scoring below happens on the CPU.
        self.w = w.detach().cpu().float()
        self.b = b.detach().cpu().float()
        self.k = k
        # Energy-scaled ID features, shape (n_train, feature_dim); set by fit_features().
        self._scaled_features: Optional[Tensor] = None

    def _logits(self, features: Tensor) -> Tensor:
        """Compute logits from features using the stored classification head."""
        return features @ self.w.T + self.b

    def fit(self: Self, data_loader: DataLoader, device=None) -> Self:
        """
        Extract features from the data loader and build the energy-scaled feature bank.

        :param data_loader: data loader with ID training data
        :param device: device for feature extraction. If ``None``, inferred from model.
        :return: the fitted detector
        """
        model_is_module = isinstance(self.model, torch.nn.Module)

        if device is None:
            # A parameter-free module (e.g. nn.Identity) has nothing to infer a device
            # from; fall back to the CPU instead of leaking a StopIteration.
            first_param = next(self.model.parameters(), None) if model_is_module else None
            device = first_param.device if first_param is not None else "cpu"
            log.warning("No device given. Will use '%s'.", device)

        if model_is_module:
            log.debug("Moving model to %s", device)
            self.model.to(device)

        z, y = extract_features(model=self.model, data_loader=data_loader, device=device)
        return self.fit_features(z, y)

    def fit_features(self: Self, z: Tensor, labels: Tensor) -> Self:
        """
        Build the energy-scaled feature bank from pre-extracted features.

        :param z: features, shape ``(n, feature_dim)``
        :param labels: corresponding labels
        :return: the fitted detector
        :raise ValueError: if ``labels`` contains no in-distribution sample
        """
        known = is_known(labels)

        if not known.any():
            raise ValueError("No ID samples")

        z = z[known].cpu().float()
        energy = torch.logsumexp(self._logits(z), dim=1)
        # Scale every ID feature vector by its own energy.
        self._scaled_features = z * energy[:, None]

        n_bank = self._scaled_features.shape[0]
        if n_bank < self.k:
            log.warning(
                "Feature bank holds only %d samples, fewer than k=%d. "
                "All of them will be used as neighbors.",
                n_bank,
                self.k,
            )

        return self

    def predict(self, x: Tensor) -> Tensor:
        """
        Compute the NNGuide outlier score for a batch of model inputs.

        :param x: model inputs
        :return: outlier scores, one per input; higher means more out-of-distribution
        :raise ModelNotSetException: if no model was given
        :raise RequiresFittingException: if the detector has not been fitted
        """
        if self.model is None:
            raise ModelNotSetException()

        if self._scaled_features is None:
            raise RequiresFittingException()

        with torch.no_grad():
            z = self.model(x)

        return self.predict_features(z)

    def predict_features(self, z: Tensor) -> Tensor:
        """
        Compute the NNGuide outlier score from pre-extracted features.

        :param z: features, shape ``(batch, feature_dim)``
        :return: outlier scores on the CPU, shape ``(batch,)``
        :raise RequiresFittingException: if the detector has not been fitted
        """
        bank = self._scaled_features
        if bank is None:
            raise RequiresFittingException()

        z = z.detach().cpu().float()
        energy = torch.logsumexp(self._logits(z), dim=1)

        # inner product between test features and energy-scaled training features
        # shape: (batch, n_train)
        sim = z @ bank.T

        # topk raises if asked for more entries than the bank contains, so never
        # request more neighbors than there are bank samples.
        k = min(self.k, bank.shape[0])

        # mean inner product over k nearest neighbors (largest inner products)
        topk_sim, _ = sim.topk(k, dim=1)
        guidance = topk_sim.mean(dim=1)

        # higher guidance * energy = more in-distribution, so negate for outlier score
        return -(guidance * energy)