# Source code for pytorch_ood.detector.openmax.torch

"""
Torch wrapper around the NumPy implementation of OpenMax.
"""
import logging
from typing import Optional, TypeVar
from .numpy import OpenMax as NumpyOpenMax

import torch
from torch import Tensor
from torch.nn import Module

from ...api import LogitsDetector, ModelNotSetException

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Type variable used to annotate methods that return ``self`` (e.g. ``fit_logits``).
Self = TypeVar("Self")


class OpenMax(LogitsDetector):
    """
    Implementation of the OpenMax Layer as proposed in the paper *Towards Open Set Deep Networks*.

    The method determines a center :math:`\\mu_y` for each class in the logits space of a model,
    and then creates a statistical model of the distances of correctly classified inputs.
    It uses extreme value theory to detect outliers by fitting a Weibull function to the tail
    of the distance distribution.

    We use the pseudo-activation of the *unknown* class as outlier score.

    All statistics are computed by the NumPy implementation; this class only converts between
    torch tensors and NumPy arrays.

    :see Paper: `ArXiv <https://arxiv.org/abs/1511.06233>`__
    :see Implementation: `GitHub <https://github.com/abhijitbendale/OSDN>`__
    """

    # Signals that this detector has to be fitted before it can produce scores
    # (presumably consulted by the base class / callers).
    requires_fit = True

    def __init__(
        self,
        model: Optional[Module],
        tailsize: int = 25,
        alpha: int = 10,
        euclid_weight: float = 1.0,
    ):
        """
        :param model: neural network, assumed to output logits. Can be ``None`` when using
            ``fit_logits(...)`` and ``predict_logits(...)`` directly.
        :param tailsize: length of the tail to fit the distribution to
        :param alpha: number of class activations to revise
        :param euclid_weight: weight for the Euclidean distance.
        """
        self.model = model
        self._openmax = NumpyOpenMax(tailsize=tailsize, alpha=alpha, euclid_weight=euclid_weight)

    def fit_logits(self: Self, logits: Tensor, y: Tensor) -> Self:
        """
        Determines parameters of the Weibull functions for each class.

        :param logits: logits given by the model
        :param y: class labels
        :return: the detector itself, so that calls can be chained
        """
        # ``.numpy()`` raises for tensors that require grad, so detach before converting.
        logits_np = logits.detach().cpu().numpy()
        y_np = y.detach().cpu().numpy()
        self._openmax.fit(logits_np, y_np)
        return self

    def predict(self, x: Tensor) -> Tensor:
        """
        Passes ``x`` through the model and computes outlier scores from the resulting logits.

        :param x: input, will be passed through the model to get logits
        :return: outlier scores, as computed by :meth:`predict_logits`
        :raise ModelNotSetException: if the detector was created without a model
        """
        if self.model is None:
            raise ModelNotSetException

        # Scores are not differentiable through the NumPy backend, so skip building the graph.
        with torch.no_grad():
            logits = self.model(x)

        return self.predict_logits(logits)

    def predict_logits(self, logits: Tensor) -> Tensor:
        """
        :param logits: logits given by model
        :return: outlier score per input, i.e. column 0 of the NumPy OpenMax output (the
            pseudo-activation of the *unknown* class). The tensor is created with
            ``torch.tensor`` and therefore is not placed on the device of ``logits``.
        """
        # ``.numpy()`` raises for tensors that require grad, so detach before converting.
        logits_np = logits.detach().cpu().numpy()
        return torch.tensor(self._openmax.predict(logits_np)[:, 0])