"""
.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
:alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
:alt: segmentation badge
.. autoclass:: pytorch_ood.detector.Mahalanobis
:members:
"""
import logging
import warnings
from typing import Callable, List, Optional, TypeVar
import torch
from torch import Tensor
from torch.autograd import Variable
from torch.utils.data import DataLoader
from ..api import Detector, ModelNotSetException, RequiresFittingException
from ..utils import (
TensorBuffer,
contains_unknown,
extract_features,
is_known,
is_unknown,
)
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Type variable used to annotate ``self`` and the return value of the fit methods,
# which return the instance they were called on.
Self = TypeVar("Self")
[docs]
class Mahalanobis(Detector):
    """
    Implements the Mahalanobis Method from the paper *A Simple Unified Framework for Detecting
    Out-of-Distribution Samples and Adversarial Attacks*.

    This method calculates a class center :math:`\\mu_y` for each class,
    and a shared covariance matrix :math:`\\Sigma` from the data.

    The outlier scores are then calculated as

    .. math :: - \\max_k \\lbrace - \\frac{1}{2} (f(x) - \\mu_k)^{\\top} \\Sigma^{-1} (f(x) - \\mu_k) \\rbrace

    i.e. half of the squared Mahalanobis distance to the closest class center.

    Also uses ODIN preprocessing.

    :see Implementation: `GitHub <https://github.com/pokaxpoka/deep_Mahalanobis_detector>`__
    :see Paper: `ArXiv <https://arxiv.org/abs/1807.03888>`__
    """

    def __init__(
        self,
        model: Callable[[Tensor], Tensor],
        eps: float = 0.002,
        norm_std: Optional[List] = None,
    ):
        """
        :param model: the Neural Network, should output features
        :param eps: magnitude for gradient based input preprocessing
        :param norm_std: Standard deviations for input normalization
        """
        super(Mahalanobis, self).__init__()
        self.model = model
        self.mu: Optional[Tensor] = None  #: Centers
        self.cov: Optional[Tensor] = None  #: Covariance Matrix
        self.precision: Optional[Tensor] = None  #: Precision Matrix
        self.eps: float = eps  #: epsilon
        self.norm_std = norm_std

    def fit(self: Self, data_loader: DataLoader, device: Optional[str] = None) -> Self:
        """
        Fit parameters of the multi variate gaussian.

        :param data_loader: dataset to fit on.
        :param device: device to use. If omitted, the device of the model's first
            parameter is used; models without parameters fall back to the CPU.
        :return: the fitted detector
        :raise ModelNotSetException: if no model was given
        """
        if self.model is None:
            raise ModelNotSetException

        if device is None:
            # Plain callables have no ``parameters()`` and some modules have no
            # parameters at all, so look the first parameter up defensively.
            params = getattr(self.model, "parameters", None)
            first_param = next(iter(params()), None) if callable(params) else None
            device = first_param.device if first_param is not None else torch.device("cpu")
            log.warning(f"No device given. Will use '{device}'.")

        if isinstance(self.model, torch.nn.Module):
            log.debug(f"Moving model to device {device}")
            self.model.to(device)

        z, y = extract_features(data_loader, self.model, device)
        return self.fit_features(z, y, device)

    def fit_features(self: Self, z: Tensor, y: Tensor, device: Optional[str] = None) -> Self:
        """
        Fit parameters of the multi variate gaussian.

        :param z: features
        :param y: class labels
        :param device: device to use. Defaults to the device of ``z``.
        :return: the fitted detector
        """
        if device is None:
            device = z.device
            log.warning(f"No device given. Will use '{device}'.")

        y = y.to(device)

        log.debug("Calculating mahalanobis parameters.")
        classes = y.unique()

        # we assume here that all classes 0 <= label <= classes.max() exist
        assert (
            len(classes) == classes.max().item() + 1
        ), "labels must cover every class index from 0 to max(y) without gaps"
        assert not contains_unknown(classes), "y must not contain unknown labels"

        n_classes = len(classes)
        n_features = z.shape[-1]
        self.mu = torch.zeros(size=(n_classes, n_features), device=device)
        self.cov = torch.zeros(size=(n_features, n_features), device=device)

        for clazz in range(n_classes):
            idxs = y.eq(clazz).to(z.device)
            assert idxs.sum() != 0, f"no samples found for class {clazz}"
            # we only move them to device after indexing to reduce ram usage.
            zs = z[idxs].to(device)
            self.mu[clazz] = zs.mean(dim=0)
            centered = zs - self.mu[clazz]
            # scatter is summed over all classes; it is not divided by the sample count
            self.cov += centered.T.mm(centered)

        # small ridge on the diagonal before inverting
        self.cov += torch.eye(self.cov.shape[0], device=self.cov.device) * 1e-6
        self.precision = torch.linalg.inv(self.cov)
        return self

    def _class_scores(self, features: Tensor) -> Tensor:
        """
        Per-class gaussian score, ``-0.5 * (f - mu_k)^T P (f - mu_k)``.

        :param features: tensor with one feature vector per row
        :return: tensor with one row per sample and one column per class
        """
        scores = []
        for clazz in range(self.n_classes):
            centered = features - self.mu[clazz]
            # Row-wise quadratic form. Same values as diag(c P c^T), but without
            # materialising the batch x batch matrix.
            scores.append(-0.5 * (centered.mm(self.precision) * centered).sum(dim=1))
        return torch.stack(scores, dim=1)

    def _calc_gaussian_scores(self, z: Tensor) -> Tensor:
        """
        Calculates the per-class scores for features as given by the model.

        All dimensions after the second one are averaged out before scoring.

        :param z: features with at least two dimensions
        :return: tensor with one row per sample and one column per class
        """
        # reshape (unlike view) also accepts non-contiguous model outputs
        features = z.reshape(z.size(0), z.size(1), -1)
        features = torch.mean(features, 2)
        # scores are computed on detached features, no graph is kept
        return self._class_scores(features.detach())

    def predict_features(self, z: Tensor) -> Tensor:
        """
        Calculates mahalanobis distance directly on features.
        ODIN preprocessing will not be applied.

        :param z: features, as given by the model.
        :return: outlier scores, one per sample
        :raise RequiresFittingException: if the detector has not been fitted
        """
        if self.mu is None:
            raise RequiresFittingException

        md_k = self._calc_gaussian_scores(z)
        score = -torch.max(md_k, dim=1).values
        return score

    def predict(self, x: Tensor) -> Tensor:
        """
        :param x: input tensor
        :return: outlier scores, one per sample
        :raise ModelNotSetException: if no model was given
        :raise RequiresFittingException: if the detector has not been fitted
        """
        if self.model is None:
            raise ModelNotSetException

        # fail before running the model instead of somewhere inside the preprocessing
        if self.mu is None:
            raise RequiresFittingException

        if self.eps > 0:
            x = self._odin_preprocess(x, x.device)

        features = self.model(x)
        return self.predict_features(features)

    def _odin_preprocess(self, x: Tensor, dev: str):
        """
        Perturbs the input by ``eps`` against the sign of the gradient of the
        distance to the closest class center.

        NOTE: the original implementation uses mean over feature maps. here, we just flatten

        :param x: input tensor
        :param dev: unused, kept so that existing callers keep working
        :return: perturbed input, detached from the graph
        """
        # does not work in inference mode, this sometimes collides with pytorch-lightning
        if torch.is_inference_mode_enabled():
            warnings.warn("ODIN not compatible with inference mode. Will be deactivated.")

        with torch.inference_mode(False):
            if torch.is_inference(x):
                # inference tensors can not take part in autograd, work on a copy
                x = x.clone()

            with torch.enable_grad():
                # fresh leaf that shares storage with the input; the caller's
                # tensor is not modified
                x = x.detach().requires_grad_(True)
                features = self.model(x)
                features = features.reshape(features.shape[0], -1)  # flatten

                # the closest class is selected without tracking gradients
                score = self._class_scores(features.detach())

                # calculate gradient of inputs with respect to score of predicted class,
                # according to mahalanobis distance
                sample_pred = score.max(dim=1).indices
                batch_sample_mean = self.mu.index_select(0, sample_pred)
                centered_features = features - batch_sample_mean
                pure_gau = -0.5 * (
                    centered_features.mm(self.precision) * centered_features
                ).sum(dim=1)
                loss = torch.mean(-pure_gau)

                # autograd.grad only returns the input gradient, while backward()
                # would also accumulate into the .grad of every model parameter
                (input_grad,) = torch.autograd.grad(loss, x)
                gradient = torch.sign(input_grad)

                if self.norm_std is not None:
                    # rescale channel i (dimension 1) by its normalization std
                    for i, std in enumerate(self.norm_std):
                        gradient[:, i] = gradient[:, i] / std

                perturbed_x = x.detach() - self.eps * gradient

        return perturbed_x

    @property
    def n_classes(self):
        """
        Number of classes the model is fitted for

        :raise RequiresFittingException: if the detector has not been fitted
        """
        if self.mu is None:
            raise RequiresFittingException

        return self.mu.shape[0]