"""
.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
:alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
:alt: segmentation badge
.. autoclass:: pytorch_ood.detector.SHE
:members:
:inherited-members:
:show-inheritance:
"""
from typing import TypeVar, Callable
import torch
from pytorch_ood.api import RequiresFittingException
from torch import nn
from torch import Tensor
from torch.utils.data import DataLoader
import logging
from pytorch_ood.utils import extract_features, is_known, TensorBuffer
from ..api import FeaturesDetector, ModelNotSetException
# Type variable used to annotate methods that return the instance they were called on
# (``fit`` and ``fit_features`` both return ``self``).
Self = TypeVar("Self")
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class SHE(FeaturesDetector):
    """
    Implements Simplified Hopfield Energy from the paper
    *Out-of-Distribution Detection based on In-Distribution Data Patterns Memorization with modern Hopfield Energy*

    For each class, SHE estimates the mean feature vector :math:`S_i` of correctly classified instances.
    For some new instances with predicted class :math:`\\hat{y}`, SHE then
    uses the negative inner product :math:`-f(x)^{\\top} S_{\\hat{y}}` as outlier score, so a lower
    similarity to the stored class pattern yields a higher score.

    :see Paper: `OpenReview <https://openreview.net/pdf?id=KkazG4lgKL>`__
    """

    requires_fit = True

    def __init__(self, backbone: Callable[[Tensor], Tensor], head: Callable[[Tensor], Tensor]):
        """
        :param backbone: feature extractor
        :param head: maps feature vectors to logits
        """
        super().__init__()
        self.backbone = backbone
        self.head = head
        # per-class mean feature vectors, stacked in class order; ``None`` until fitted
        self.patterns = None
        # flipped to ``True`` by :meth:`fit_features` once the patterns have been stored
        self.is_fitted = False

    def predict(self, x: Tensor) -> Tensor:
        """
        Extracts features with the backbone and scores them with :meth:`predict_features`.

        :param x: model inputs
        :return: outlier scores
        :raise ModelNotSetException: if no backbone was given
        """
        if self.backbone is None:
            raise ModelNotSetException()

        z = self.backbone(x)
        return self.predict_features(z)

    def predict_features(self, z: Tensor) -> Tensor:
        """
        Scores features by the negative inner product with the pattern of the predicted class.

        :param z: features as given by the model
        :return: outlier scores, one per row of ``z``
        :raise ModelNotSetException: if no head was given
        :raise RequiresFittingException: if the detector has not been fitted yet
        """
        if self.head is None:
            raise ModelNotSetException(msg="When using predict_features(), head must not be None")

        if self.patterns is None:
            raise RequiresFittingException()

        y_hat = self.head(z).argmax(dim=1)
        # keep the stored patterns on the device of the predictions so they can be indexed by them
        self.patterns = self.patterns.to(y_hat.device)
        similarity = (z * self.patterns[y_hat]).sum(dim=1)
        # negated: the less a sample resembles the pattern of its predicted class, the higher the score
        return -similarity

    def fit(self: Self, data_loader: DataLoader) -> Self:
        """
        Extracts features and calculates mean patterns.

        :param data_loader: data to fit
        :return: the fitted detector
        :raise ModelNotSetException: if no backbone was given
        """
        # fail with the same exception as predict() instead of an opaque "NoneType is not callable"
        if self.backbone is None:
            raise ModelNotSetException()

        device = self.device
        if device is None:
            device = "cpu"
            log.warning("No device set. Will use '%s'.", device)
            self.to(device)

        x, y = extract_features(data_loader, self.backbone, device=device)
        return self.fit_features(x, y)

    @torch.no_grad()
    def _filter_correct_predictions(self, z: Tensor, y: Tensor, batch_size: int = 1024):
        """
        Keeps only those samples for which the prediction of the head matches the label.

        :param z: a tensor of shape (N, D) or similar
        :param y: labels of shape (N,)
        :param batch_size: how many samples we process at a time
        :return: tuple with the features and the labels of the correctly classified samples
        """
        device = self.device or z.device
        buffer = TensorBuffer()

        # process in chunks so that only ``batch_size`` samples are moved to the device at once
        for start_idx in range(0, z.size(0), batch_size):
            end_idx = start_idx + batch_size
            z_batch = z[start_idx:end_idx].to(device)
            y_batch = y[start_idx:end_idx].to(device)

            y_hat_batch = self.head(z_batch).argmax(dim=1)
            mask = y_hat_batch == y_batch

            buffer.append("z", z_batch[mask])
            # on the mask, predictions and labels are identical
            buffer.append("y", y_hat_batch[mask])

        return buffer["z"], buffer["y"]

    def fit_features(self: Self, z: Tensor, y: Tensor, batch_size: int = 1024) -> Self:
        """
        Calculates mean patterns per class.

        :param z: features to fit
        :param y: labels
        :param batch_size: how many samples we process at a time
        :return: the fitted detector
        :raise ValueError: if there are no ID samples, if the ID labels do not cover all classes
            from 0 up to the largest label, or if some class has no correctly classified sample
        :raise ModelNotSetException: if no head was given
        """
        device = self.device or z.device

        known = is_known(y)
        if not known.any():
            raise ValueError("No ID samples")

        y = y[known]
        z = z[known]
        classes = y.unique()

        # Patterns are stacked in class order and later indexed by the predicted class, so the
        # labels have to be exactly 0..C-1. Checked explicitly (not via ``assert``) so that the
        # validation is still performed when running with ``python -O``.
        num_classes = classes.max().item() + 1
        if len(classes) != num_classes:
            raise ValueError(
                f"Labels of ID samples must cover all classes 0..{num_classes - 1}, "
                f"but only {len(classes)} distinct classes were found"
            )

        if self.head is None:
            raise ModelNotSetException(msg="When using fit_features(), head must not be None")

        z, y = self._filter_correct_predictions(z, y, batch_size=batch_size)

        patterns = []
        # iterate over plain python values, so the comparison below does not depend on the
        # device of the (unfiltered) label tensor
        for clazz in classes.tolist():
            idx = y == clazz
            if not idx.any():
                raise ValueError(f"No correct predictions for class {clazz}")

            patterns.append(z[idx].to(device).mean(dim=0))

        self.patterns = torch.stack(patterns)
        self.is_fitted = True
        return self