Source code for pytorch_ood.detector.vra

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-yes-brightgreen?style=flat-square
   :alt: segmentation badge
.. image:: https://img.shields.io/badge/AI_Coded-yes-blue?style=flat-square
   :alt: slop-badge

..  autoclass:: pytorch_ood.detector.VRA
    :members:
    :inherited-members:
    :show-inheritance:

"""

import logging
from typing import Callable, Optional, TypeVar

import numpy as np
import torch.nn
from torch import Tensor
from torch.utils.data import DataLoader

from pytorch_ood.utils import extract_features, is_known

from ..api import FeatureMapsDetector, ModelNotSetException, RequiresFittingException
from .energy import EnergyBased

log = logging.getLogger(__name__)
Self = TypeVar("Self")


class VRA(FeatureMapsDetector):
    """
    Implements VRA from the paper
    *Variational Rectified Activation for Out-of-Distribution Detection*.

    VRA is a two-sided version of ReAct that clips activations both above and
    below using percentile thresholds learned from In-Distribution data, then
    scores the result with an outlier detector (Energy-Based by default).

    Unlike ReAct, which only clips activations from above at a fixed
    threshold, VRA learns per-dimension lower and upper clipping bounds from
    the training data using configurable percentiles.

    Example Code:

    .. code :: python

        model = WideResNet()
        detector = VRA(
            backbone=model.features,
            head=model.fc,
        )
        detector.fit(train_loader)
        scores = detector(images)

    :see Paper: `ArXiv <https://arxiv.org/abs/2302.11716>`__
    """

    requires_fit = True

    def __init__(
        self,
        backbone: Callable[[Tensor], Tensor],
        head: Callable[[Tensor], Tensor],
        lower_percentile: float = 1.0,
        upper_percentile: float = 99.0,
        detector: Optional[Callable[[Tensor], Tensor]] = None,
    ):
        """
        :param backbone: first part of the model, should output feature maps
        :param head: second part of the model used after clipping, should output logits
        :param lower_percentile: lower percentile for clipping threshold (default 1.0)
        :param upper_percentile: upper percentile for clipping threshold (default 99.0)
        :param detector: detector that maps outputs to outlier scores. Default is energy based.
        :raise ValueError: if the percentiles do not satisfy
            ``0 <= lower_percentile <= upper_percentile <= 100``
        """
        # Reject bad percentiles up front: an inverted pair would otherwise make
        # ``clip`` silently collapse every activation onto the upper bound, and
        # out-of-range values would only surface later, while fitting.
        # The chained comparison is also False for NaN, so NaN is rejected too.
        if not 0.0 <= lower_percentile <= upper_percentile <= 100.0:
            raise ValueError(
                "Percentiles must satisfy "
                "0 <= lower_percentile <= upper_percentile <= 100, got "
                f"lower_percentile={lower_percentile}, "
                f"upper_percentile={upper_percentile}"
            )

        # NOTE(review): the base-class initializer is not invoked here; confirm
        # that FeatureMapsDetector does not rely on it (``fit`` reads
        # ``self.device`` and calls ``self.to``).
        self.backbone = backbone
        self.head = head
        self.lower_percentile = lower_percentile
        self.upper_percentile = upper_percentile
        # Compare against None explicitly: a valid callable may be falsy
        # (e.g. a container module of length zero) and must not be replaced.
        self.detector = detector if detector is not None else EnergyBased.score

        # Clipping bounds; populated by ``fit`` / ``fit_feature_maps``.
        self._lower_threshold = None
        self._upper_threshold = None

    def predict(self, x: Tensor) -> Tensor:
        """
        :param x: input, will be passed through the network
        :return: output of the detector applied to the head's output
        :raise ModelNotSetException: if no backbone (or head) is set
        :raise RequiresFittingException: if the thresholds have not been fitted
        """
        if self.backbone is None:
            raise ModelNotSetException()

        # Fail before running the backbone if the detector was never fitted.
        if self._lower_threshold is None:
            raise RequiresFittingException()

        z = self.backbone(x)
        return self.predict_feature_maps(z)

    def predict_feature_maps(self, x: Tensor) -> Tensor:
        """
        :param x: features from the backbone
        :return: output of the detector applied to the head's output
        :raise ModelNotSetException: if no head is set
        :raise RequiresFittingException: if the thresholds have not been fitted
        """
        if self.head is None:
            raise ModelNotSetException()

        if self._lower_threshold is None:
            raise RequiresFittingException()

        clipped = self._clip(x)
        logits = self.head(clipped)
        return self.detector(logits)

    def fit_feature_maps(self: Self, z: Tensor, y: Tensor) -> Self:
        """
        Calculate per-dimension clipping thresholds from In-Distribution features.
        OOD inputs will be ignored.

        :param z: features
        :param y: labels
        :return: the detector itself
        :raise ValueError: if ``y`` contains no In-Distribution samples
        """
        known = is_known(y)

        if not known.any():
            raise ValueError("No ID data")

        features = z[known].detach().cpu().numpy()

        # Both percentiles are computed in a single call along the sample
        # axis, so the feature array is only processed once.
        lower, upper = np.percentile(
            features, [self.lower_percentile, self.upper_percentile], axis=0
        )
        self._lower_threshold = torch.from_numpy(lower.astype(np.float32))
        self._upper_threshold = torch.from_numpy(upper.astype(np.float32))

        log.info(
            f"Lower threshold range: [{self._lower_threshold.min():.2f}, {self._lower_threshold.max():.2f}]"
        )
        log.info(
            f"Upper threshold range: [{self._upper_threshold.min():.2f}, {self._upper_threshold.max():.2f}]"
        )
        return self

    def fit(self: Self, data_loader: DataLoader) -> Self:
        """
        Extract features and calculate clipping thresholds.
        OOD inputs will be ignored.

        :param data_loader: data loader to extract features from
        :return: the detector itself
        :raise ModelNotSetException: if no backbone is set
        """
        if self.backbone is None:
            raise ModelNotSetException()

        device = self.device
        if device is None:
            # No device configured: fall back to the CPU and say so.
            device = "cpu"
            log.warning(f"No device set. Will use '{device}'.")
            self.to(device)

        z, y = extract_features(data_loader, self.backbone, device=device)
        self.fit_feature_maps(z, y)
        return self

    def _clip(self, z: Tensor) -> Tensor:
        """
        Clamp ``z`` element-wise between the fitted lower and upper thresholds.

        :param z: feature maps to clip
        :return: clipped feature maps
        """
        # Thresholds are stored as float32. Cast them to the dtype of
        # floating-point inputs so that clipping e.g. half-precision features
        # does not promote the result to float32 before it reaches the head.
        dtype = z.dtype if z.is_floating_point() else self._lower_threshold.dtype
        lower = self._lower_threshold.to(device=z.device, dtype=dtype)
        upper = self._upper_threshold.to(device=z.device, dtype=dtype)
        return z.clip(min=lower, max=upper)