Source code for pytorch_ood.detector.mmahalanobis

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: segmentation badge

..  autoclass:: pytorch_ood.detector.MultiMahalanobis
    :members:
"""

import logging
from typing import List, TypeVar

import torch
from torch import Tensor
from torch.nn import Module, Sequential
from torch.utils.data import DataLoader

from ..api import Detector, ModelNotSetException, RequiresFittingException
from ..utils import contains_unknown, extract_feature_avg

log = logging.getLogger(__name__)

Self = TypeVar("Self")


class MultiMahalanobis(Detector):
    """
    Implements the Mahalanobis Method from the paper *A Simple Unified Framework for Detecting
    Out-of-Distribution Samples and Adversarial Attacks* which supports several layers.

    For each of the given :math:`i` layers, the method calculates a class center :math:`\\mu_{iy}`
    for each class, and a shared covariance matrix :math:`\\Sigma_i` from the data.

    The per-layer outlier scores are calculated as

    .. math ::
        M_i(x) = \\frac{1}{2} \\min_k \\lbrace (f_i(x) - \\mu_{ik})^{\\top} \\Sigma_i^{-1} (f_i(x) - \\mu_{ik}) \\rbrace

    i.e. half the squared Mahalanobis distance to the closest class center, so higher scores
    indicate outliers. The final outlier score is the sum of all per-layer scores, weighted by
    :math:`\\alpha`.

    Example code is provided :doc:`here <auto_examples/detectors/mmahalanobis>`

    .. note :: This does not yet support ODIN preprocessing. Also, the :math:`\\alpha` values have to be
        determined manually.

    :see Implementation: `GitHub <https://github.com/pokaxpoka/deep_Mahalanobis_detector>`__
    :see Paper: `ArXiv <https://arxiv.org/abs/1807.03888>`__
    """

    def __init__(self, model: List[Module], alpha: List[float] = None):
        """
        :param model: the neural network layers :math:`f_1(\\cdot),...,f_n(\\cdot)`, output of one will be
            used as input to the next.
        :param alpha: weighting of the individual layers. Defaults to uniform weighting.
        :raises ValueError: if ``model`` is empty
        """
        super(MultiMahalanobis, self).__init__()
        if len(model) == 0:
            raise ValueError("No modules given")

        self.model = model

        # parameters of Gaussians, one entry per layer; populated by fit_features
        self.mu: List[Tensor] = []  #: Centers
        self.cov: List[Tensor] = []  #: Covariance Matrices
        self.precision: List[Tensor] = []  #: Precision Matrices

        if alpha is None:
            # uniform weighting by default if alpha is not given
            alpha = [1.0] * len(model)

        self.alpha = alpha  #: Per-layer weighting factors

    def _default_device(self):
        """Device of the first parameter found in any layer, or ``"cpu"`` if no layer has parameters."""
        for layer in self.model:
            for param in layer.parameters():
                return param.device
        return "cpu"

    def fit(self: Self, data_loader: DataLoader, device: str = None) -> Self:
        """
        Fit one gaussian to the features of each layer. Will average over feature maps.

        :param data_loader: dataset to fit on.
        :param device: device to use. Defaults to the device of the first parameter found in the layers.
        :return: the fitted detector
        """
        if device is None:
            device = self._default_device()
            log.warning(f"No device given. Will use '{device}'.")

        # self.model may be a plain list (which has no .to()), so move every layer individually
        log.debug(f"Moving model to {device}")
        for layer in self.model:
            layer.to(device)

        zs = []
        for layer_idx in range(len(self.model)):
            # NOTE: this could be done more efficiently
            model = Sequential(*self.model[: layer_idx + 1])
            log.debug(f"Extracting for layer {layer_idx}")
            z, y = extract_feature_avg(data_loader, model, device)
            log.debug(f"Extracted {z.shape} features for {y.shape[0]} samples.")
            zs.append(z)

        # y comes from the last extraction; every pass iterates the same data loader
        return self.fit_features(zs, y, device)

    def fit_features(self: Self, zs: List[Tensor], y: Tensor, device: str = None) -> Self:
        """
        Fit parameters of the multi variate gaussians.

        Calling this again replaces previously fitted parameters.

        :param zs: list of features for each layer, each of shape ``(n_samples, n_features)``
        :param y: class labels
        :param device: device to use. Defaults to the device of ``zs[0]``.
        :return: the fitted detector
        """
        if device is None:
            device = zs[0].device
            log.warning(f"No device given. Will use '{device}'.")

        y = y.to(device)
        classes = y.unique()

        # we assume that every label c with 0 <= c <= classes.max() occurs at least once
        assert len(classes) == classes.max().item() + 1
        assert not contains_unknown(classes)

        n_classes = len(classes)

        # collect into fresh lists and assign at the end, so re-fitting replaces (instead of
        # appending to) old parameters and a failure midway leaves the previous fit intact
        mus: List[Tensor] = []
        covs: List[Tensor] = []
        precisions: List[Tensor] = []

        for layer_idx, z in enumerate(zs):
            z = z.to(device)
            log.debug(
                f"Calculating mahalanobis parameters for layer {layer_idx} with {n_classes=} {z.shape=} {y.shape=}"
            )
            n_features = z.shape[-1]
            mu = torch.zeros(size=(n_classes, n_features), device=device)
            cov = torch.zeros(size=(n_features, n_features), device=device)

            for clazz in range(n_classes):
                idxs = y.eq(clazz)
                assert idxs.sum() != 0
                z_c = z[idxs]
                mu[clazz] = z_c.mean(dim=0)
                centered = z_c - mu[clazz]
                # NOTE: pooled scatter matrix, not divided by the number of samples; relative to an
                # empirical covariance this scales all distances by a common constant factor
                cov += centered.T.mm(centered)

            # small ridge term keeps the matrix invertible
            cov += torch.eye(cov.shape[0], device=cov.device) * 1e-6
            precision = torch.linalg.inv(cov)

            mus.append(mu)
            covs.append(cov)
            precisions.append(precision)

        self.mu = mus
        self.cov = covs
        self.precision = precisions
        return self

    def _calc_gaussian_scores(self, z: Tensor, layer_idx: int) -> Tensor:
        """
        Calculates per-class Gaussian terms for the features of one layer.

        :param z: features of layer ``layer_idx`` with at least two dimensions; every dimension after
            the second is averaged out
        :param layer_idx: index of the layer whose fitted parameters are used
        :return: tensor of shape ``(batch, n_classes)`` holding :math:`-\\frac{1}{2}` times the squared
            Mahalanobis distance to each class center
        """
        features = z.reshape(z.size(0), z.size(1), -1).mean(dim=2).detach()
        # fitted parameters may live on a different device than the features
        mu = self.mu[layer_idx].to(features.device)
        precision = self.precision[layer_idx].to(features.device)

        md_k = []
        # calculate per class scores
        for clazz in range(self.n_classes):
            centered_z = features - mu[clazz]
            # row-wise quadratic form x^T P x; avoids building the (batch x batch) matrix
            # that mm(mm(x, P), x.t()).diag() would materialize
            term_gau = -0.5 * (centered_z.mm(precision) * centered_z).sum(dim=1)
            md_k.append(term_gau.view(-1, 1))

        return torch.cat(md_k, 1)

    def predict_features(self, zs: List[Tensor], device=None) -> Tensor:
        """
        Calculates mahalanobis distance directly on features.
        ODIN preprocessing will not be applied.

        :param zs: list of per-layer features
        :param device: device to use for computations. Defaults to the device of ``zs[0]``.
        :return: weighted sum of per-layer outlier scores, one per sample
        :raises RequiresFittingException: if the detector has not been fitted
        """
        if not self.mu:
            raise RequiresFittingException

        if not device:
            device = zs[0].device

        batch_size = zs[0].shape[0]
        scores = torch.empty(batch_size, len(zs), device=device)

        for layer_idx, z in enumerate(zs):
            md_k = self._calc_gaussian_scores(z.to(device), layer_idx)
            # md_k holds -0.5 * squared distance, so the negated max is
            # 0.5 * squared distance to the closest class center
            score = -torch.max(md_k, dim=1).values
            scores[:, layer_idx] = self.alpha[layer_idx] * score

        return scores.sum(dim=1)

    def predict(self, x: Tensor) -> Tensor:
        """
        Calculates outlier scores by feeding the input through the layers.

        :param x: input tensor
        :return: weighted sum of per-layer outlier scores, one per sample
        :raises ModelNotSetException: if no layers are set
        :raises RequiresFittingException: if the detector has not been fitted
        """
        if not self.model:
            raise ModelNotSetException

        if not self.mu:
            raise RequiresFittingException

        zs = []
        z = x
        # a single pass collecting every intermediate output; for deterministic layers this is
        # equivalent to evaluating each prefix Sequential(*model[: i + 1]) separately
        for layer in self.model:
            z = layer(z)
            # average over feature planes (every dimension after the channel dimension)
            zs.append(z.reshape(z.shape[0], z.shape[1], -1).mean(dim=2))

        return self.predict_features(zs, device=x.device)

    @property
    def n_classes(self):
        """
        Number of classes the model is fitted for
        """
        if not self.mu:
            raise RequiresFittingException

        return self.mu[0].shape[0]