Source code for pytorch_ood.detector.klmatching

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: segmentation badge

..  autoclass:: pytorch_ood.detector.KLMatching
    :members:
    :inherited-members:
    :show-inheritance:

"""

import logging
from typing import Optional, TypeVar

import torch
from torch import Tensor
from torch.nn import Module, Parameter, ParameterDict
from ..api import LogitsDetector, ModelNotSetException, RequiresFittingException

# Module-level logger named after this module (instead of the root logger), so
# applications can filter or silence its output independently.
log = logging.getLogger(__name__)

# Type variable used to annotate methods that return the instance they were called on.
Self = TypeVar("Self")


class KLMatching(LogitsDetector):
    """
    Implements KL-Matching from the paper
    *Scaling Out-of-Distribution Detection for Real-World Settings*.

    For each class, a typical posterior distribution
    :math:`d_y = \\mathbb{E}_{x \\sim \\mathcal{X}_{val}}[p(y \\vert x)]` is estimated,
    where :math:`y` is the class with the maximum posterior
    :math:`y = \\arg\\max_y p(y \\vert x)`, as predicted by the model.
    Note that the method, as described in the paper, does not require class labels for
    the validation set. :meth:`fit_logits` groups the samples by whatever ``labels`` it
    receives, so the predicted classes ``logits.argmax(dim=1)`` can be passed instead
    of ground-truth labels.

    During evaluation, the KL-Divergence between the observed and the typical posterior
    :math:`D_{KL}[p(y \\vert x) \\Vert d_y]` is used as outlier score.

    This method can also be applied to multi-class settings.

    :see Paper: `ArXiv <https://arxiv.org/abs/1911.11132>`__
    """

    requires_fit = True

    def __init__(self, model: Optional[Module]):
        """
        :param model: neural network, is assumed to output logits. Can be ``None`` when
            using ``fit_logits(...)`` and ``predict_logits(...)`` directly.
        """
        super().__init__()
        self.model = model
        self.dists: ParameterDict = ParameterDict()  #: Typical posteriors per class

    def fit_logits(self: Self, logits: Tensor, labels: Tensor) -> Self:
        """
        Estimates typical distributions for each class. Ignores OOD samples.

        The typical posterior of a class is the mean softmax output over all samples
        that carry this class label. It is stored in ``self.dists`` under the key
        ``str(label)``.

        :param logits: logits, one row per sample
        :param labels: class labels, one entry per row of ``logits``
        :return: the detector itself, so that calls can be chained
        """
        device = self.device or logits.device
        # The boolean masks derived from the labels are used to index the
        # probabilities, so both have to live on the same device.
        labels = labels.to(logits.device)
        probabilities = logits.softmax(dim=1)

        for label in labels.unique():
            # NOTE(review): negative labels are assumed to mark OOD samples
            # (pytorch-ood convention). Predictions are argmax indices and therefore
            # never negative, so such entries could never be looked up anyway.
            if label.item() < 0:
                continue

            log.debug("Fitting class %s", label)
            d_k = probabilities[labels == label].to(device).mean(dim=0)
            self.dists[str(label.item())] = Parameter(d_k)

        return self

    def predict_logits(self, logits: Tensor) -> Tensor:
        """
        Calculates outlier scores directly from logits.

        :param logits: logits predicted by the model
        :return: outlier scores, one per row of ``logits``
        :raises ValueError: if a predicted class has no fitted typical posterior
        """
        return self._score_probabilities(logits.softmax(dim=1))

    def _score_probabilities(self, p: Tensor) -> Tensor:
        """
        Score already-computed posterior probabilities.

        Each row is compared against the typical posterior of its predicted
        (arg-max) class; the KL-Divergence between the two is the outlier score.

        :param p: probabilities predicted by the model
        :return: outlier scores, one per row of ``p``
        :raises ValueError: if a predicted class has no fitted typical posterior
        """
        predictions = p.argmax(dim=1)
        # Allocate with the dtype of the input: masked assignment requires source
        # and destination dtypes to match, which fails for e.g. float64 inputs when
        # the buffer is created with the default dtype.
        scores = torch.empty(p.shape[0], device=p.device, dtype=p.dtype)

        for label in predictions.unique():
            key = str(label.item())
            if key not in self.dists:
                raise ValueError(f"Label {label.item()} not fitted.")

            mask = predictions == label
            class_p = p[mask]
            # Shape (1, C) broadcasts against (n, C); no need to materialise copies.
            # Moving to the device of ``p`` keeps direct calls working when the
            # posteriors were fitted on another device.
            class_d = self.dists[key].to(device=p.device).unsqueeze(0)

            terms = class_p * (class_p / class_d).log()
            # A probability that underflowed to exactly 0 yields 0 * log(0) = NaN.
            # The KL-Divergence defines this term as 0.
            terms = terms.masked_fill(class_p == 0, 0.0)

            scores[mask] = terms.sum(dim=1).to(scores.dtype)

        return scores

    def predict(self, x: Tensor) -> Tensor:
        """
        Calculates KL-Divergence between predicted posteriors and typical posteriors.

        :param x: input tensor, will be passed through model
        :return: Outlier scores
        :raises RequiresFittingException: if no typical posteriors have been fitted
        :raises ModelNotSetException: if the detector was created without a model
        """
        if len(self.dists) == 0:
            raise RequiresFittingException("KL-Matching has to be fitted on validation data.")

        if self.model is None:
            raise ModelNotSetException

        # we move the dict with the typical posteriors to the same device as the input
        # this might be not desirable in some cases, but avoids errors
        self.dists.to(x.device)

        logits = self.model(x)
        return self.predict_logits(logits)