# Source code for pytorch_ood.detector.she

"""
.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: segmentation badge

..  autoclass:: pytorch_ood.detector.SHE
    :members:
"""
import warnings
from typing import Callable, TypeVar

import torch
from torch import Tensor
from torch.utils.data import DataLoader

from pytorch_ood.utils import extract_features, is_known

from ..api import Detector, ModelNotSetException

# Type variable used to annotate ``self`` in fluent methods (``fit``, ``fit_features``),
# so that their return type is the type of the instance they are called on.
Self = TypeVar("Self")


class SHE(Detector):
    """
    Implements Simplified Hopfield Energy from the paper
    *Out-of-Distribution Detection based on In-Distribution Data Patterns Memorization with modern
    Hopfield Energy*

    For each class, SHE estimates the mean feature vector :math:`S_i` of correctly classified
    instances. For some new instances with predicted class :math:`\\hat{y}`, SHE then uses the inner
    product :math:`f(x)^{\\top} S_{\\hat{y}}` as outlier score. The returned score is the *negated*
    inner product, so inputs that align more strongly with the pattern of their predicted class
    receive lower scores.

    :see Paper: `OpenReview <https://openreview.net/pdf?id=KkazG4lgKL>`__
    """

    def __init__(self, model: Callable[[Tensor], Tensor], head: Callable[[Tensor], Tensor]):
        """
        :param model: feature extractor
        :param head: maps feature vectors to logits
        """
        super().__init__()
        self.model = model
        self.head = head
        # Per-class mean feature vectors stacked along dim 0 (row i belongs to class i);
        # populated by fit_features.
        self.patterns = None
        self.is_fitted = False

    def predict(self, x: Tensor) -> Tensor:
        """
        Computes outlier scores for raw model inputs.

        :param x: model inputs
        :return: outlier scores as computed by :meth:`predict_features`
        :raise ModelNotSetException: if no feature extractor was given
        """
        if self.model is None:
            raise ModelNotSetException
        return self.predict_features(self.model(x))

    def predict_features(self, z: Tensor) -> Tensor:
        """
        Computes outlier scores from extracted features.

        :param z: features as given by the model (presumably of shape ``(N, D)`` — the sum over
            ``dim=1`` assumes features are on the second axis)
        :return: negated inner product between each feature vector and the stored pattern of its
            predicted class
        :raise RuntimeError: if the detector has not been fitted yet
        """
        if self.patterns is None:
            raise RuntimeError("SHE must be fitted before it can compute outlier scores")

        y_hat = self.head(z).argmax(dim=1)
        # Patterns stay on the device used during fitting; align them with the incoming features.
        patterns = self.patterns.to(z.device)
        return -(z * patterns[y_hat]).sum(dim=1)

    def fit(self: Self, loader: DataLoader, device: str = "cpu") -> Self:
        """
        Extracts features and calculates mean patterns.

        :param loader: data to fit
        :param device: device to use for computations
        :return: the fitted detector
        :raise ModelNotSetException: if no feature extractor was given
        """
        if self.model is None:
            raise ModelNotSetException
        x, y = extract_features(loader, self.model, device=device)
        return self.fit_features(x.to(device), y.to(device))

    def fit_features(self: Self, z: Tensor, y: Tensor) -> Self:
        """
        Calculates mean patterns per class.

        Only samples whose labels pass ``is_known`` are used. For every class, the pattern is the
        mean of the features of its samples that the head classifies correctly. If no sample of a
        class is classified correctly, the mean over all samples of that class is used instead
        (and a warning is emitted), so the pattern never becomes NaN.

        :param z: features to fit
        :param y: labels
        :return: the fitted detector
        :raise ValueError: if there are no known samples, or the known labels do not cover every
            class ``0 .. C-1``
        """
        known = is_known(y)
        if not known.any():
            raise ValueError("No IN samples")

        y = y[known]
        z = z[known]

        classes = y.unique()
        # Patterns are indexed by predicted class, so every class 0..C-1 must be present.
        n_classes = classes.max().item() + 1
        if len(classes) != n_classes:
            raise ValueError(
                f"Expected samples for all classes 0..{n_classes - 1}, "
                f"but only {len(classes)} distinct classes are present"
            )

        # select correctly classified
        correct = self.head(z).argmax(dim=1) == y

        patterns = []
        for clazz in classes:
            of_class = y == clazz
            selected = of_class & correct
            if not selected.any():
                # mean() of an empty selection would be NaN and poison every score for this class
                warnings.warn(
                    f"No correctly classified samples for class {clazz.item()}; "
                    "using the mean of all its samples as pattern"
                )
                selected = of_class
            patterns.append(z[selected].mean(dim=0))

        self.patterns = torch.stack(patterns)
        self.is_fitted = True
        return self