Source code for pytorch_ood.api

import functools
import logging
from abc import ABC, abstractmethod
from typing import TypeVar

import torch
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader

# Logger for this module; named after the module so handlers can filter on it.
log = logging.getLogger(name=__name__)
# Type variable used to annotate methods that return ``self`` (fluent API).
Self = TypeVar("Self")


class RequiresFittingException(Exception):
    """
    Signals that scores were requested from a detector that still needs fitting.

    :param msg: error message; defaults to a hint to call ``fit()`` first.
    """

    def __init__(self, msg: str = "You have to call fit() before predict()"):
        super().__init__(msg)


class ModelNotSetException(ValueError):
    """
    Signals that ``predict()`` needs a model, but the detector has none.

    :param msg: error message; defaults to a note that the model must not be ``None``.
    """

    def __init__(self, msg: str = "When using predict(), model must not be None"):
        super().__init__(msg)


class Detector(ABC):
    """
    Root public API for out-of-distribution detectors.

    Every detector supports ``predict(x)`` on raw model inputs and a generic
    ``fit(data_loader)`` entry point. Semantic subclasses refine this contract with
    alternate representation-specific methods such as ``predict_logits(...)`` or
    ``predict_features(...)``.
    """

    requires_fit = False  #: Whether ``fit(...)`` must be called before scoring.

    @staticmethod
    def _move_value_to_device(value, device: torch.device):
        """
        Move common detector-owned state to the given device.

        Modules (and modules that own a bound method) are moved in place. Tensors are
        replaced by their moved version. Lists, tuples (including named tuples) and
        dicts are traversed recursively. Any other value is returned unchanged.

        :param value: value to move
        :param device: target device
        :return: the moved value (possibly the same object)
        """
        if isinstance(value, Module):
            value.to(device)
            return value
        if isinstance(value, Tensor):
            return value.to(device)
        owner = getattr(value, "__self__", None)
        if isinstance(owner, Module):
            # bound method of a module (e.g. ``model.forward``): move its owner
            owner.to(device)
            return value
        if isinstance(value, list):
            return [Detector._move_value_to_device(v, device) for v in value]
        if isinstance(value, tuple):
            moved = [Detector._move_value_to_device(v, device) for v in value]
            if all(new is old for new, old in zip(moved, value)):
                # tuples are immutable, so an unchanged one can be returned as-is;
                # this also keeps tuple subclasses such as ``torch.Size`` intact
                return value
            if hasattr(value, "_fields"):
                # named tuples take their fields positionally; a plain ``tuple(...)``
                # would silently drop the type and its attribute access
                return type(value)(*moved)
            return tuple(moved)
        if isinstance(value, dict):
            # updated in place, so the dict object (and its type) is preserved
            for key, inner in value.items():
                value[key] = Detector._move_value_to_device(inner, device)
            return value
        return value

    @staticmethod
    def _move_tensor_arguments_to_device(args, kwargs, device: torch.device):
        """
        Move tensor-valued positional and keyword arguments to ``device``.

        Only top-level tensors are moved; containers are passed through untouched.

        :param args: positional arguments
        :param kwargs: keyword arguments
        :param device: target device
        :return: tuple ``(moved_args, moved_kwargs)``
        """
        moved_args = tuple(
            value.to(device) if isinstance(value, Tensor) else value for value in args
        )
        moved_kwargs = {
            key: value.to(device) if isinstance(value, Tensor) else value
            for key, value in kwargs.items()
        }
        return moved_args, moved_kwargs

    @staticmethod
    def _infer_value_device(value):
        """
        Infer the device of common detector-owned state, if any.

        For a module, the device of its first parameter (or, if it has none, its first
        buffer) is used. Containers are searched in order, and the first inferable
        device is returned.

        :param value: value to inspect
        :return: the inferred device, or ``None``
        """
        if isinstance(value, Module):
            try:
                return next(value.parameters()).device
            except StopIteration:
                try:
                    return next(value.buffers()).device
                except StopIteration:
                    return None
        if isinstance(value, Tensor):
            return value.device
        owner = getattr(value, "__self__", None)
        if isinstance(owner, Module):
            return Detector._infer_value_device(owner)
        if isinstance(value, (list, tuple)):
            for inner in value:
                device = Detector._infer_value_device(inner)
                if device is not None:
                    return device
            return None
        if isinstance(value, dict):
            for inner in value.values():
                device = Detector._infer_value_device(inner)
                if device is not None:
                    return device
            return None
        return None

    @property
    def device(self):
        """
        The device of the detector's owned torch state, if one can be inferred.

        Instance attributes are inspected in order, and the first inferable device wins.
        If none is found, this is the device last passed to :meth:`to`, or ``None``.
        """
        for value in vars(self).values():
            device = self._infer_value_device(value)
            if device is not None:
                return device
        return getattr(self, "_device", None)

    def to(self: Self, device) -> Self:
        """
        Move detector-owned modules and tensor state to ``device``.

        This is a detector-level analogue of ``nn.Module.to(...)``. It moves modules,
        tensors, and common container-valued state stored on the detector itself.

        :param device: target torch device
        :return: self
        """
        device = torch.device(device)
        self._device = device
        # snapshot the items, since attributes are reassigned while iterating
        for attr, value in list(vars(self).items()):
            if attr == "_device":
                continue
            setattr(self, attr, self._move_value_to_device(value, device))
        return self

    def __call__(self, *args, **kwargs) -> Tensor:
        """
        Forwards to predict
        """
        return self.predict(*args, **kwargs)

    def fit(self: Self, data_loader: DataLoader) -> Self:
        """
        Fit the detector to a dataset. Some methods require this.

        The base implementation is a no-op for detectors that do not require fitting.

        :param data_loader: dataset to fit on. This is usually the training dataset.
        :return: self
        :raise NotImplementedError: if ``requires_fit`` is set but the subclass does not
            implement ``fit()``
        """
        if self.requires_fit:
            raise NotImplementedError(
                f"{type(self).__name__} requires fitting and must implement fit()."
            )
        return self

    @abstractmethod
    def predict(self, x: Tensor) -> Tensor:
        """
        Calculates outlier scores. Inputs will be passed through the model.

        :param x: batch of data
        :return: outlier scores for points

        :raise RequiresFittingException: if detector has to be fitted to some data
        :raise ModelNotSetException: if model was not set
        """
        raise NotImplementedError
class LogitsDetector(Detector):
    """
    Base class for detectors whose alternate public API consumes logits.

    Subclasses implement ``predict_logits(...)`` and optionally ``fit_logits(...)``.
    The default ``predict(x)`` and ``fit(data_loader)`` implementations forward raw
    inputs through ``self.model`` to obtain logits first.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Only methods defined directly on ``cls`` are wrapped, so each implementation
        # is wrapped exactly once, by the class that defines it.
        predict_logits = cls.__dict__.get("predict_logits")
        if predict_logits is not None:
            cls.predict_logits = cls._wrap_representation_method(predict_logits)

    @staticmethod
    def _wrap_representation_method(method):
        """
        Wrap ``method`` so that its tensor arguments are moved to the detector's
        device (when one can be inferred) before it is called.

        ``functools.wraps`` keeps the name, docstring, module and signature (via
        ``__wrapped__``) of the original for introspection and documentation.

        :param method: function taking ``self`` as its first argument
        :return: the wrapping function
        """

        @functools.wraps(method)
        def wrapped(self, *args, **kwargs):
            device = self.device
            if device is not None:
                args, kwargs = self._move_tensor_arguments_to_device(args, kwargs, device)
            return method(self, *args, **kwargs)

        return wrapped

    def predict(self, x: Tensor) -> Tensor:
        """
        Apply the model and forward its logits to ``predict_logits(...)``.

        :param x: input batch
        :return: outlier scores
        :raise ModelNotSetException: if ``self.model`` is missing or ``None``
        """
        if getattr(self, "model", None) is None:
            raise ModelNotSetException

        detector_device = self.device
        if detector_device is not None:
            x = x.to(detector_device)

        return self.predict_logits(self.model(x))

    def fit(self: Self, data_loader: DataLoader) -> Self:
        """
        Extract logits from a loader and forward them to ``fit_logits(...)``.

        If the subclass does not implement ``fit_logits(...)``, no data is extracted.
        Detectors that do not require fitting return ``self`` immediately (as
        :meth:`Detector.fit` does), and those that do require it raise.

        :param data_loader: loader to extract logits from
        :return: the result of ``fit_logits(...)`` (conventionally ``self``)
        :raise ModelNotSetException: if ``self.model`` is missing or ``None``
        :raise NotImplementedError: if fitting is required but ``fit_logits`` is not
            implemented
        """
        if type(self).fit_logits is LogitsDetector.fit_logits:
            # nothing to fit: avoid a full, wasted pass over the data
            if self.requires_fit:
                raise NotImplementedError(
                    f"{type(self).__name__} requires fitting and must implement fit_logits()."
                )
            return self

        if getattr(self, "model", None) is None:
            raise ModelNotSetException

        device = self.device
        if device is None:
            device = "cpu"
            log.warning("No device set. Will use '%s'.", device)
        self.to(device)

        from .utils import extract_features

        z, y = extract_features(data_loader=data_loader, model=self.model, device=device)
        return self.fit_logits(z, y)

    def fit_logits(self: Self, logits: Tensor, y: Tensor) -> Self:
        """
        Fit the detector directly on logits.

        :param logits: training logits to use for fitting.
        :param y: corresponding class labels.
        :raise NotImplementedError: unless overridden by a subclass
        """
        raise NotImplementedError

    def predict_logits(self, logits: Tensor) -> Tensor:
        """
        Calculates outlier scores directly from logits.

        :param logits: batch of logits
        :return: outlier scores for points
        :raise NotImplementedError: unless overridden by a subclass
        """
        raise NotImplementedError
class FeaturesDetector(Detector):
    """
    Base class for detectors whose alternate public API consumes one feature tensor.

    Subclasses implement ``predict_features(...)`` and, when fitting is required,
    ``fit_features(...)``.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Only methods defined directly on ``cls`` are wrapped, so each implementation
        # is wrapped exactly once, by the class that defines it.
        predict_features = cls.__dict__.get("predict_features")
        if predict_features is not None:
            cls.predict_features = cls._wrap_representation_method(predict_features)

    @staticmethod
    def _wrap_representation_method(method):
        """
        Wrap ``method`` so that its tensor arguments are moved to the detector's
        device (when one can be inferred) before it is called.

        ``functools.wraps`` keeps the name, docstring, module and signature (via
        ``__wrapped__``) of the original for introspection and documentation.

        :param method: function taking ``self`` as its first argument
        :return: the wrapping function
        """

        @functools.wraps(method)
        def wrapped(self, *args, **kwargs):
            device = self.device
            if device is not None:
                args, kwargs = self._move_tensor_arguments_to_device(args, kwargs, device)
            return method(self, *args, **kwargs)

        return wrapped

    def fit_features(self: Self, x: Tensor, y: Tensor) -> Self:
        """
        Fit the detector directly on feature tensors.

        :param x: training features to use for fitting
        :param y: corresponding class labels
        :raise NotImplementedError: unless overridden by a subclass
        """
        raise NotImplementedError

    def predict_features(self, x: Tensor) -> Tensor:
        """
        Calculate outlier scores directly from feature tensors.

        :param x: batch of features
        :return: outlier scores for points
        :raise NotImplementedError: unless overridden by a subclass
        """
        raise NotImplementedError
class FeatureMapsDetector(Detector):
    """
    Base class for detectors whose alternate public API consumes feature maps.

    Subclasses implement ``predict_feature_maps(...)`` and, when fitting is required,
    ``fit_feature_maps(...)``.
    """

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Only methods defined directly on ``cls`` are wrapped, so each implementation
        # is wrapped exactly once, by the class that defines it.
        predict_feature_maps = cls.__dict__.get("predict_feature_maps")
        if predict_feature_maps is not None:
            cls.predict_feature_maps = cls._wrap_representation_method(predict_feature_maps)

    @staticmethod
    def _wrap_representation_method(method):
        """
        Wrap ``method`` so that its tensor arguments are moved to the detector's
        device (when one can be inferred) before it is called.

        ``functools.wraps`` keeps the name, docstring, module and signature (via
        ``__wrapped__``) of the original for introspection and documentation.

        :param method: function taking ``self`` as its first argument
        :return: the wrapping function
        """

        @functools.wraps(method)
        def wrapped(self, *args, **kwargs):
            device = self.device
            if device is not None:
                args, kwargs = self._move_tensor_arguments_to_device(args, kwargs, device)
            return method(self, *args, **kwargs)

        return wrapped

    def fit_feature_maps(self: Self, feature_maps: Tensor, y: Tensor) -> Self:
        """
        Fit the detector directly on feature maps.

        :param feature_maps: training feature maps to use for fitting.
        :param y: corresponding class labels.
        :raise NotImplementedError: unless overridden by a subclass
        """
        raise NotImplementedError

    def predict_feature_maps(self, feature_maps: Tensor) -> Tensor:
        """
        Calculates outlier scores directly from feature maps.

        :param feature_maps: batch of feature maps
        :return: outlier scores for points
        :raise NotImplementedError: unless overridden by a subclass
        """
        raise NotImplementedError
class StructuredDetector(Detector):
    """
    Base class for detectors whose alternate public API takes structured inputs.

    Intended for detectors whose non-model interface does not fit a single tensor
    family, e.g. lists of per-layer features, or mixed inputs such as logits together
    with feature maps.
    """

    def fit_structured(self: Self, *args, **kwargs) -> Self:
        """
        Fit the detector directly on structured intermediate representations.

        This base version is only a placeholder for subclasses to override.

        :raise NotImplementedError: unless overridden by a subclass
        """
        raise NotImplementedError

    def predict_structured(self, *args, **kwargs) -> Tensor:
        """
        Calculates outlier scores directly from structured intermediate representations.

        This base version is only a placeholder for subclasses to override.

        :raise NotImplementedError: unless overridden by a subclass
        """
        raise NotImplementedError