"""
.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: segmentation badge
.. image:: https://img.shields.io/badge/AI_Coded-yes-blue?style=flat-square
   :alt: slop-badge

.. autoclass:: pytorch_ood.detector.PNML
    :members:
    :inherited-members:
    :show-inheritance:
"""
import logging
from typing import Callable, Optional, TypeVar
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from ..api import FeaturesDetector, ModelNotSetException, RequiresFittingException
from ..utils import extract_features, is_known
# Module-level logger; used by PNML.fit() to warn when no device is set.
log = logging.getLogger(__name__)
# Type variable used for the fluent ``fit``/``fit_features`` return annotations.
Self = TypeVar("Self")
class PNML(FeaturesDetector):
    """
    Implements the pNML regret from
    *Single Layer Predictive Normalized Maximum Likelihood for Out-of-Distribution Detection*.

    Uses normalized penultimate-layer features together with the classifier probabilities
    to compute the pNML regret. Higher regret indicates a more likely OOD sample.

    For normalized training features :math:`X`, their Moore-Penrose pseudoinverse
    :math:`X^+`, normalized test feature :math:`z`, classifier probabilities
    :math:`p_i(z)`, and
    :math:`\\kappa(z) = \\frac{z^\\top X^+ X^{+\\top} z}{1 + z^\\top X^+ X^{+\\top} z}`,
    the detector scores a sample by

    .. math::

        R(z) = \\frac{1}{\\log C}
        \\log \\sum_{i=1}^{C}
        \\frac{p_i(z)}{p_i(z) + (1 - p_i(z)) p_i(z)^{\\kappa(z)}}.

    Intuitively, the score is low when a sample is well supported by the training feature geometry
    and the classifier is confident, and high when the sample falls in weakly supported directions.

    :see Paper:
        `ArXiv <https://arxiv.org/abs/2110.09246>`__

    :see Implementation:
        `GitHub <https://github.com/kobybibas/pnml_ood_detection>`__
    """

    requires_fit = True

    def __init__(
        self,
        encoder: Optional[Callable[[Tensor], Tensor]],
        head: Optional[Callable[[Tensor], Tensor]],
        eps: float = 1e-12,
    ):
        """
        :param encoder: feature encoder mapping inputs to penultimate-layer features
        :param head: classification head mapping L2-normalized features to logits
        :param eps: numerical stability constant; probabilities are clamped to
            ``[eps, 1 - eps]`` before the regret is computed
        """
        self.encoder = encoder
        self.head = head
        self.eps = eps
        # Set by fit_features(): X^+ X^{+T} of the normalized training features.
        self._feature_projector: Optional[Tensor] = None

    def fit(self: Self, data_loader: DataLoader) -> Self:
        """
        Extract features with ``encoder`` and fit the pNML detector on them.

        :param data_loader: data loader with training data
        :return: the fitted detector
        :raises ModelNotSetException: if ``encoder`` is ``None``
        """
        if self.encoder is None:
            raise ModelNotSetException()

        device = self.device
        if device is None:
            device = "cpu"
            log.warning(f"No device set. Will use '{device}'.")
        # Ensure the detector lives on the device used for feature extraction.
        self.to(device)

        z, y = extract_features(data_loader, self.encoder, device)
        return self.fit_features(z, y)

    def fit_features(self: Self, z: Tensor, labels: Optional[Tensor] = None) -> Self:
        """
        Fit pNML directly on penultimate-layer features.

        Only the features enter the fit. Labels are used solely to drop samples that
        ``is_known`` rejects (OOD-marked samples). If ``labels`` is ``None``, every row
        of ``z`` is used.

        :param z: training features, expected shape ``(num_samples, feature_dim)``
        :param labels: class labels, or ``None`` to use all samples
        :return: the fitted detector
        :raises ValueError: if no in-distribution samples remain
        """
        if labels is None:
            known = torch.ones(z.shape[0], dtype=torch.bool, device=z.device)
        else:
            # Labels may live on a different device than z; index on z's device.
            known = is_known(labels).to(z.device)
        if not known.any():
            raise ValueError("No ID samples found.")

        target_device = self.device or z.device
        z = z[known].detach().to(target_device).float()
        z = torch.nn.functional.normalize(z, p=2, dim=1)

        # For z of shape (N, D): pinv is (D, N), so the projector X^+ X^{+T} is (D, D)
        # and is reused for every test sample in predict_features().
        x_pinv = torch.linalg.pinv(z)
        self._feature_projector = x_pinv @ x_pinv.T
        return self

    @torch.no_grad()
    def predict(self, x: Tensor) -> Tensor:
        """
        Compute outlier scores for raw inputs.

        :param x: input tensor, passed through ``encoder`` and then scored by
            :meth:`predict_features`
        :return: outlier scores (higher = more OOD)
        :raises ModelNotSetException: if ``encoder`` is ``None``
        """
        if self.encoder is None:
            raise ModelNotSetException()
        # predict_features() detaches its input, so no autograd graph is needed here.
        return self.predict_features(self.encoder(x))

    @torch.no_grad()
    def predict_features(self, z: Tensor) -> Tensor:
        """
        Calculate outlier scores using the normalized pNML regret.

        :param z: penultimate-layer features, expected shape ``(num_samples, feature_dim)``
        :return: outlier scores (higher = more OOD)
        :raises ModelNotSetException: if ``head`` is ``None``
        :raises RequiresFittingException: if the detector has not been fitted
        :raises ValueError: if ``head`` produces fewer than two classes, where the
            normalization by :math:`\\log C` is undefined
        """
        if self.head is None:
            raise ModelNotSetException(msg="When using predict_features(), head must not be None")
        if self._feature_projector is None:
            raise RequiresFittingException()

        # Compute on the device of the fitted projector.
        device = self._feature_projector.device
        z = z.detach().to(device).float()
        z = torch.nn.functional.normalize(z, p=2, dim=1)

        probs = torch.softmax(self.head(z), dim=1)
        num_classes = probs.shape[1]
        if num_classes < 2:
            raise ValueError(f"pNML requires at least 2 classes, got {num_classes}.")
        # Keep probabilities away from 0 so p ** kappa and the division below stay finite.
        probs = probs.clamp(min=self.eps, max=1.0 - self.eps)

        # kappa(z) = z^T X^+ X^{+T} z / (1 + z^T X^+ X^{+T} z), one value per row of z.
        x_proj = (z @ self._feature_projector * z).sum(dim=1)
        kappa = x_proj / (1.0 + x_proj)

        regret_terms = probs / (probs + (1.0 - probs) * probs.pow(kappa.unsqueeze(1)))
        regret = torch.log(regret_terms.sum(dim=1))
        # Normalize by log C, as in the formula in the class docstring.
        return regret / regret.new_tensor(float(num_classes)).log()