Source code for pytorch_ood.detector.mahalanobis

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: classification badge

..  autoclass:: pytorch_ood.detector.Mahalanobis
    :members:
    :inherited-members:
    :show-inheritance:
"""

import logging
import warnings
from typing import Callable, List, Optional, TypeVar

import torch
from torch import Tensor
from torch.autograd import Variable
from torch.utils.data import DataLoader

from ..api import (
    FeaturesDetector,
    GradientDetector,
    ModelNotSetException,
    RequiresFittingException,
)
from ..utils import (
    contains_unknown,
    extract_features,
)

log = logging.getLogger(__name__)

Self = TypeVar("Self")


[docs] class Mahalanobis(FeaturesDetector): """ Implements the Mahalanobis Method from the paper *A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks*. This method calculates a class center :math:`\\mu_y` for each class, and a shared covariance matrix :math:`\\Sigma` from the data. The outlier scores are then calculated as .. math :: - \\max_k \\lbrace (f(x) - \\mu_k)^{\\top} \\Sigma^{-1} (f(x) - \\mu_k) \\rbrace :see Implementation: `GitHub <https://github.com/pokaxpoka/deep_Mahalanobis_detector>`__ :see Paper: `ArXiv <https://arxiv.org/abs/1807.03888>`__ """ requires_fit = True def __init__( self, encoder: Optional[Callable[[Tensor], Tensor]], ): """ :param encoder: feature encoder. Can be ``None`` when using ``fit_features(...)`` and ``predict_features(...)`` directly. """ super(Mahalanobis, self).__init__() self.encoder = encoder self.mu: Tensor = None #: Centers self.cov: Tensor = None #: Covariance Matrix self.precision: Tensor = None #: Precision Matrix
[docs] def fit(self: Self, data_loader: DataLoader) -> Self: """ Fit parameters of the multi variate gaussian. :param data_loader: dataset to fit on. """ device = self.device if device is None: device = "cpu" log.warning(f"No device set. Will use '{device}'.") self.to(device) z, y = extract_features(data_loader, self.encoder, device) return self.fit_features(z, y)
[docs] def fit_features(self: Self, z: Tensor, y: Tensor) -> Self: """ Fit parameters of the multi variate gaussian. :param z: features :param y: class labels """ device = self.device or z.device z = z.detach().to(device) y = y.to(device) log.debug("Calculating mahalanobis parameters.") classes = y.unique() # we assume here that all class 0 >= labels <= classes.max() exist assert len(classes) == classes.max().item() + 1 assert not contains_unknown(classes) n_classes = len(classes) self.mu = torch.zeros(size=(n_classes, z.shape[-1]), device=device) self.cov = torch.zeros(size=(z.shape[-1], z.shape[-1]), device=device) for clazz in range(n_classes): idxs = y.eq(clazz) assert idxs.sum() != 0 zs = z[idxs] self.mu[clazz] = zs.mean(dim=0) self.cov += (zs - self.mu[clazz]).T.mm(zs - self.mu[clazz]) self.cov += torch.eye(self.cov.shape[0], device=self.cov.device) * 1e-6 self.precision = torch.linalg.inv(self.cov) return self
def _calc_gaussian_scores(self, z: Tensor) -> Tensor: """ """ features = z.view(z.size(0), z.size(1), -1) features = torch.mean(features, 2) md_k = [] # calculate per class scores for clazz in range(self.n_classes): centered_z = features.data - self.mu[clazz] term_gau = -0.5 * ((centered_z @ self.precision) * centered_z).sum(dim=1) md_k.append(term_gau.view(-1, 1)) return torch.cat(md_k, 1) @torch.no_grad() def predict_features(self, z: Tensor) -> Tensor: """ Calculates mahalanobis distance directly on features. ODIN preprocessing will not be applied. :param z: features, as given by the model. """ if self.mu is None: raise RequiresFittingException md_k = self._calc_gaussian_scores(z) score = -torch.max(md_k, dim=1).values return score
[docs] @torch.no_grad() def predict(self, x: Tensor) -> Tensor: """ :param x: input tensor """ if self.encoder is None: raise ModelNotSetException features = self.encoder(x) return self.predict_features(features)
@property def n_classes(self): """ Number of classes the model is fitted for """ if self.mu is None: raise RequiresFittingException return self.mu.shape[0]
[docs] class MahalanobisODIN(GradientDetector): """ Mahalanobis distance detector with ODIN input preprocessing. Combines the Mahalanobis distance from *A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks* with ODIN input perturbation from *Enhancing The Reliability of Out-of-distribution Image Detection in Neural Networks*. Adds gradient-guided input perturbation (FGSM-style) before computing the Mahalanobis distance. Requires gradient computation during prediction. :see Paper (Mahalanobis): `ArXiv <https://arxiv.org/abs/1807.03888>`__ :see Paper (ODIN): `ArXiv <https://arxiv.org/abs/1706.02690>`__ """ requires_fit = True def __init__( self, encoder: Optional[Callable[[Tensor], Tensor]], eps: float = 0.002, norm_std: Optional[List] = None, ): """ :param encoder: feature encoder. Can be ``None`` when using ``fit_features(...)`` and ``predict_features(...)`` directly. :param eps: magnitude for gradient based input preprocessing :param norm_std: Standard deviations for input normalization """ super(MahalanobisODIN, self).__init__() self._base = Mahalanobis(encoder) self.eps = eps self.norm_std = norm_std
[docs] def fit(self: Self, data_loader: DataLoader) -> Self: """ Fit the underlying Mahalanobis detector. :param data_loader: dataset to fit on. """ self._base.fit(data_loader) return self
[docs] def fit_features(self: Self, z: Tensor, y: Tensor) -> Self: """ Fit the underlying Mahalanobis detector on features. :param z: features :param y: class labels """ self._base.fit_features(z, y) return self
[docs] def predict(self, x: Tensor) -> Tensor: """ Apply ODIN input perturbation and compute Mahalanobis distance. :param x: input tensor """ x = self._odin_preprocess(x, x.device) return self._base.predict(x)
def _odin_preprocess(self, x: Tensor, dev: str): """ ODIN input preprocessing guided by Mahalanobis distance. """ if torch.is_inference_mode_enabled(): warnings.warn("ODIN not compatible with inference mode. Will be deactivated.") with torch.inference_mode(False): if torch.is_inference(x): x = x.clone() with torch.enable_grad(): x = Variable(x, requires_grad=True) features = self._base.encoder(x) features = features.view(features.shape[0], -1) score = None for clazz in range(self._base.n_classes): centered_features = features.data - self._base.mu[clazz] term_gau = ( -0.5 * torch.mm( torch.mm(centered_features, self._base.precision), centered_features.t(), ).diag() ) if clazz == 0: score = term_gau.view(-1, 1) else: score = torch.cat((score, term_gau.view(-1, 1)), dim=1) sample_pred = score.max(dim=1).indices batch_sample_mean = self._base.mu.index_select(0, sample_pred) centered_features = features - Variable(batch_sample_mean) pure_gau = ( -0.5 * torch.mm( torch.mm(centered_features, Variable(self._base.precision)), centered_features.t(), ).diag() ) loss = torch.mean(-pure_gau) loss.backward() gradient = torch.sign(x.grad.data) if self.norm_std: for i, std in enumerate(self.norm_std): gradient.index_copy_( 1, torch.LongTensor([i]).to(dev), gradient.index_select(1, torch.LongTensor([i]).to(dev)) / std, ) perturbed_x = x.data - self.eps * gradient return perturbed_x