Source code for pytorch_ood.detector.nnguide

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: segmentation badge
.. image:: https://img.shields.io/badge/AI_Coded-yes-blue?style=flat-square
   :alt: AI-coded badge

..  autoclass:: pytorch_ood.detector.NNGuide
    :members:

"""

import logging
from typing import Callable, Optional, TypeVar

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from pytorch_ood.api import FeaturesDetector, ModelNotSetException, RequiresFittingException
from pytorch_ood.utils import extract_features, is_known

log = logging.getLogger(__name__)
Self = TypeVar("Self")


class NNGuide(FeaturesDetector):
    """
    Implements *Nearest Neighbor Guidance for Out-of-Distribution Detection*.

    Guides classifier-based scores using k-NN similarity to an energy-weighted feature bank.
    The feature bank is constructed by scaling in-distribution training features with their
    corresponding energy scores. At inference, the outlier score is the negated product of the
    k-NN guidance (mean inner product with the energy-scaled feature bank) and the sample's
    own energy:

    .. math::
        s(x) = - \\underbrace{\\frac{1}{k} \\sum_{z \\in \\mathcal{N}_k(x)} \\langle f(x),\\, E(z) \\cdot f(z) \\rangle}_{\\text{guidance}} \\cdot E(x)

    where :math:`E(x) = \\log \\sum_i \\exp(l_i(x))` is the energy score,
    :math:`f(x)` are the penultimate-layer features, and :math:`\\mathcal{N}_k(x)` are the
    :math:`k` nearest neighbors in the energy-scaled feature bank measured by inner product.

    In this implementation, features are L2-normalized before they enter the bank or are
    compared against it, and the energies that scale the bank are divided by the mean
    training energy.

    The encoder extracts penultimate-layer features. The head computes logits from features
    and is used internally to compute energy scores, similar to :class:`ViM`.

    Example Code:

    .. code :: python

        model = WideResNet()
        detector = NNGuide(
            encoder=model.features,
            head=model.fc,
            k=10
        )
        detector.fit(train_loader)
        scores = detector(images)

    :see Paper: `arXiv <https://arxiv.org/abs/2309.14888>`__
    """

    requires_fit = True

    # Upper bound on the number of entries of a single (queries x bank) similarity
    # matrix. Queries are processed in chunks so that large banks do not require
    # materializing the full similarity matrix at once.
    _MAX_SIMILARITY_ELEMENTS = 2**24

    def __init__(
        self,
        encoder: Callable[[Tensor], Tensor],
        head: Callable[[Tensor], Tensor],
        k: int = 10,
    ):
        """
        :param encoder: neural network that extracts penultimate-layer features
        :param head: callable that maps features to logits (e.g., a linear layer)
        :param k: number of nearest neighbors for guidance (default: 10)
        :raises ValueError: if ``k`` is smaller than 1
        """
        super(NNGuide, self).__init__()
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")

        self.encoder = encoder
        self.head = head
        self.k = k
        # Energy-scaled, L2-normalized ID features; ``None`` until fitted.
        self._scaled_features: Optional[Tensor] = None

    def fit(self: Self, data_loader: DataLoader, device=None) -> Self:
        """
        Extract features from the data loader and build the energy-scaled feature bank.

        :param data_loader: data loader with ID training data
        :param device: device for feature extraction. If ``None``, inferred from encoder.
        :raises ModelNotSetException: if no encoder was given
        """
        if self.encoder is None:
            raise ModelNotSetException()

        encoder_is_module = isinstance(self.encoder, torch.nn.Module)

        if device is None:
            device = self.device

        if device is None:
            # Fall back to the device of the first encoder parameter. A plain
            # callable or a parameter-less module gives no hint, so use the CPU.
            first_param = next(self.encoder.parameters(), None) if encoder_is_module else None
            device = "cpu" if first_param is None else first_param.device
            log.warning("No device given. Will use '%s'.", device)

        if encoder_is_module:
            log.debug("Moving model to %s", device)
            self.encoder.to(device)

        z, y = extract_features(model=self.encoder, data_loader=data_loader, device=device)
        return self.fit_features(z, y)

    def fit_features(self: Self, z: Tensor, labels: Tensor) -> Self:
        """
        Build the energy-scaled feature bank from pre-extracted features.

        :param z: features, shape ``(n, feature_dim)``
        :param labels: corresponding labels
        :raises ValueError: if none of the samples is in-distribution
        """
        known = is_known(labels)
        if not known.any():
            raise ValueError("No ID samples")

        device = self.device or z.device
        z = z[known].detach().to(device).float()

        if isinstance(self.head, torch.nn.Module):
            self.head.to(device)

        with torch.no_grad():
            logits = self.head(z)
            energy = torch.logsumexp(logits, dim=1)

            # Unit-length features keep the inner products bounded; dividing the
            # energies by their mean keeps the scaling factors close to one.
            z_norm = z / (z.norm(dim=1, keepdim=True) + 1e-8)
            energy_norm = energy / (energy.mean() + 1e-8)
            self._scaled_features = z_norm * energy_norm[:, None]

        if self.k > self._scaled_features.shape[0]:
            log.warning(
                "k=%d exceeds the %d samples in the feature bank; all of them will be used.",
                self.k,
                self._scaled_features.shape[0],
            )

        return self

    def predict(self, x: Tensor) -> Tensor:
        """
        :param x: model inputs
        :raises ModelNotSetException: if no encoder was given
        :raises RequiresFittingException: if the detector has not been fitted
        """
        if self.encoder is None:
            raise ModelNotSetException()

        if self._scaled_features is None:
            raise RequiresFittingException()

        with torch.no_grad():
            z = self.encoder(x)

        return self.predict_features(z)

    @torch.no_grad()
    def predict_features(self, z: Tensor) -> Tensor:
        """
        Compute the NNGuide outlier score from pre-extracted features.

        :param z: features, shape ``(batch, feature_dim)``
        :raises RequiresFittingException: if the detector has not been fitted
        """
        if self._scaled_features is None:
            raise RequiresFittingException()

        device = self.device or z.device
        if isinstance(self.head, torch.nn.Module):
            self.head.to(device)
        # The head was just moved to ``device``; its input has to live there too.
        z = z.to(device)

        logits = self.head(z)
        energy = torch.logsumexp(logits, dim=1)

        # Compare in the bank's dtype and on the bank's device, which may differ
        # from the device the head runs on.
        bank = self._scaled_features
        queries = z.to(device=bank.device, dtype=bank.dtype)
        queries = queries / (queries.norm(dim=1, keepdim=True) + 1e-8)

        guidance = self._guidance(queries).to(energy.device)

        # higher guidance * energy = more in-distribution, so negate for outlier score
        return -(guidance * energy)

    def _guidance(self, queries: Tensor) -> Tensor:
        """
        Mean inner product between each query and its nearest bank entries.

        Neighbors are ranked by inner product with the energy-scaled bank, so the
        energy weights influence both which neighbors are selected and the value
        that is averaged. A scale-invariant metric such as cosine distance would
        discard those weights.

        :param queries: L2-normalized features on the bank's device, shape ``(batch, feature_dim)``
        :return: guidance term, shape ``(batch,)``
        """
        bank = self._scaled_features
        n_bank = bank.shape[0]
        k = min(self.k, n_bank)
        rows_per_chunk = max(1, self._MAX_SIMILARITY_ELEMENTS // n_bank)

        bank_t = bank.T
        parts = [
            (chunk @ bank_t).topk(k, dim=1).values.mean(dim=1)
            for chunk in queries.split(rows_per_chunk)
        ]
        return torch.cat(parts)