Source code for pytorch_ood.detector.mahalanobis

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: classification badge

..  autoclass:: pytorch_ood.detector.Mahalanobis
    :members:
    :inherited-members:
    :show-inheritance:
"""

import logging
import warnings
from typing import Callable, List, Optional, TypeVar

import torch
from torch import Tensor
from torch.autograd import Variable
from torch.utils.data import DataLoader

from ..api import FeaturesDetector, ModelNotSetException, RequiresFittingException
from ..utils import (
    TensorBuffer,
    contains_unknown,
    extract_features,
    is_known,
    is_unknown,
)

log = logging.getLogger(__name__)

Self = TypeVar("Self")


[docs] class Mahalanobis(FeaturesDetector): """ Implements the Mahalanobis Method from the paper *A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks*. This method calculates a class center :math:`\\mu_y` for each class, and a shared covariance matrix :math:`\\Sigma` from the data. The outlier scores are then calculated as .. math :: - \\max_k \\lbrace (f(x) - \\mu_k)^{\\top} \\Sigma^{-1} (f(x) - \\mu_k) \\rbrace Also uses ODIN preprocessing if the given :math:`\\epsilon > 0` :see Implementation: `GitHub <https://github.com/pokaxpoka/deep_Mahalanobis_detector>`__ :see Paper: `ArXiv <https://arxiv.org/abs/1807.03888>`__ """ requires_fit = True def __init__( self, model: Optional[Callable[[Tensor], Tensor]], eps: float = 0.002, norm_std: Optional[List] = None, ): """ :param model: the Neural Network, should output features. Can be ``None`` when using ``fit_features(...)`` and ``predict_features(...)`` directly. :param eps: magnitude for gradient based input preprocessing :param norm_std: Standard deviations for input normalization """ super(Mahalanobis, self).__init__() self.model = model self.mu: Tensor = None #: Centers self.cov: Tensor = None #: Covariance Matrix self.precision: Tensor = None #: Precision Matrix self.eps: float = eps #: epsilon self.norm_std = norm_std
[docs] def fit(self: Self, data_loader: DataLoader) -> Self: """ Fit parameters of the multi variate gaussian. :param data_loader: dataset to fit on. """ device = self.device if device is None: device = "cpu" log.warning(f"No device set. Will use '{device}'.") self.to(device) z, y = extract_features(data_loader, self.model, device) return self.fit_features(z, y)
[docs] def fit_features(self: Self, z: Tensor, y: Tensor) -> Self: """ Fit parameters of the multi variate gaussian. :param z: features :param y: class labels """ device = self.device or z.device y = y.to(device) log.debug("Calculating mahalanobis parameters.") classes = y.unique() # we assume here that all class 0 >= labels <= classes.max() exist assert len(classes) == classes.max().item() + 1 assert not contains_unknown(classes) n_classes = len(classes) self.mu = torch.zeros(size=(n_classes, z.shape[-1]), device=device) self.cov = torch.zeros(size=(z.shape[-1], z.shape[-1]), device=device) for clazz in range(n_classes): idxs = y.eq(clazz).to(z.device) assert idxs.sum() != 0 # we only move them to device after indexing to reduce ram usage. zs = z[idxs].to(device) self.mu[clazz] = zs.mean(dim=0) self.cov += (zs - self.mu[clazz]).T.mm(zs - self.mu[clazz]) self.cov += torch.eye(self.cov.shape[0], device=self.cov.device) * 1e-6 self.precision = torch.linalg.inv(self.cov) return self
def _calc_gaussian_scores(self, z: Tensor) -> Tensor: """ """ features = z.view(z.size(0), z.size(1), -1) features = torch.mean(features, 2) md_k = [] # calculate per class scores for clazz in range(self.n_classes): centered_z = features.data - self.mu[clazz] term_gau = -0.5 * torch.mm(torch.mm(centered_z, self.precision), centered_z.t()).diag() md_k.append(term_gau.view(-1, 1)) return torch.cat(md_k, 1) def predict_features(self, z: Tensor) -> Tensor: """ Calculates mahalanobis distance directly on features. ODIN preprocessing will not be applied. :param z: features, as given by the model. """ if self.mu is None: raise RequiresFittingException md_k = self._calc_gaussian_scores(z) score = -torch.max(md_k, dim=1).values return score
[docs] def predict(self, x: Tensor) -> Tensor: """ :param x: input tensor """ if self.model is None: raise ModelNotSetException if self.eps > 0: x = self._odin_preprocess(x, x.device) features = self.model(x) return self.predict_features(features)
def _odin_preprocess(self, x: Tensor, dev: str): """ NOTE: the original implementation uses mean over feature maps. here, we just flatten """ # does not work in inference mode, this sometimes collides with pytorch-lightning if torch.is_inference_mode_enabled(): warnings.warn("ODIN not compatible with inference mode. Will be deactivated.") with torch.inference_mode(False): if torch.is_inference(x): x = x.clone() with torch.enable_grad(): x = Variable(x, requires_grad=True) features = self.model(x) features = features.view(features.shape[0], -1) # flatten score = None for clazz in range(self.n_classes): centered_features = features.data - self.mu[clazz] term_gau = ( -0.5 * torch.mm( torch.mm(centered_features, self.precision), centered_features.t(), ).diag() ) if clazz == 0: score = term_gau.view(-1, 1) else: score = torch.cat((score, term_gau.view(-1, 1)), dim=1) # calculate gradient of inputs with respect to score of predicted class, # according to mahalanobis distance sample_pred = score.max(dim=1).indices batch_sample_mean = self.mu.index_select(0, sample_pred) centered_features = features - Variable(batch_sample_mean) pure_gau = ( -0.5 * torch.mm( torch.mm(centered_features, Variable(self.precision)), centered_features.t(), ).diag() ) loss = torch.mean(-pure_gau) loss.backward() gradient = torch.sign(x.grad.data) if self.norm_std: for i, std in enumerate(self.norm_std): gradient.index_copy_( 1, torch.LongTensor([i]).to(dev), gradient.index_select(1, torch.LongTensor([i]).to(dev)) / std, ) perturbed_x = x.data - self.eps * gradient return perturbed_x @property def n_classes(self): """ Number of classes the model is fitted for """ if self.mu is None: raise RequiresFittingException return self.mu.shape[0]