Source code for pytorch_ood.detector.energy

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-yes-brightgreen?style=flat-square
   :alt: segmentation badge

..  autoclass:: pytorch_ood.detector.EnergyBased
    :members:
    :exclude-members: fit, fit_features
"""

from typing import Optional, TypeVar

from torch import Tensor, logsumexp
from torch.nn import Module

from ..api import Detector, ModelNotSetException

# Type variable used to annotate ``self`` and the return value of methods that
# hand back the very instance they were called on (``fit`` / ``fit_features``).
Self = TypeVar("Self")


class EnergyBased(Detector):
    """
    Implements the Energy Score of *Energy-based Out-of-distribution Detection*.

    This method computes the following energy score from a vector of logits.
    This value can be used as outlier score.

    .. math::
        E(x) = -T \\log{\\sum_i e^{f_i(x)/T}}

    where :math:`f_i(x)` indicates the :math:`i^{th}` logit value predicted by :math:`f`.

    :see Paper:
        `NeurIPS <https://proceedings.neurips.cc/paper/2020/file/f5496252609c43eb8a3d147ab9b9c006-Paper.pdf>`__

    :see Implementation:
        `GitHub <https://github.com/wetliu/energy_ood>`__
    """

    def __init__(self, model: Optional[Module], t: Optional[float] = 1.0):
        """
        :param model: module whose output is used as logits by :meth:`predict`.
            May be ``None``; :meth:`predict` then raises ``ModelNotSetException``.
        :param t: Temperature value :math:`T`. Default is 1. ``None`` selects the default.
        """
        super().__init__()
        # The signature advertises ``Optional[float]``; normalize ``None`` here so
        # the attribute always holds a usable number.
        self.t: float = 1.0 if t is None else t  #: Temperature
        self.model = model

    def fit(self: Self, *args, **kwargs) -> Self:
        """
        Not required. All arguments are ignored.

        :return: the detector itself
        """
        return self

    def fit_features(self: Self, *args, **kwargs) -> Self:
        """
        Not required. All arguments are ignored.

        :return: the detector itself
        """
        return self

    def predict(self, x: Tensor) -> Tensor:
        """
        Calculate the energy score for inputs

        :param x: input tensor, will be passed through model
        :return: Energy score
        :raise ModelNotSetException: if no model was given
        """
        if self.model is None:
            raise ModelNotSetException

        return self.score(self.model(x), t=self.t)

    def predict_features(self, logits: Tensor) -> Tensor:
        """
        Calculate the energy score for precomputed logits

        :param logits: logits given by the model
        :return: Energy score
        """
        # Resolved through ``self`` (like in ``predict``) so that both entry
        # points always share the same scoring function.
        return self.score(logits, t=self.t)

    @staticmethod
    def score(logits: Tensor, t: Optional[float] = 1.0) -> Tensor:
        """
        :param logits: logits of input; the reduction runs over dimension 1
        :param t: temperature value. ``None`` is treated as 1. Must be non-zero,
            since the logits are divided by it.
        :return: ``-t * logsumexp(logits / t, dim=1)``
        """
        if t is None:
            # ``None`` is permitted by the annotation; fall back to the default.
            t = 1.0

        return -t * logsumexp(logits / t, dim=1)