"""
.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: segmentation badge

.. autoclass:: pytorch_ood.detector.ViM
   :members:
   :inherited-members:
   :show-inheritance:
"""
import logging
from typing import Callable, Optional, TypeVar
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from ..api import FeaturesDetector, ModelNotSetException, RequiresFittingException
from ..utils import extract_features
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Type variable used to annotate methods that return the instance they were
# called on (see ``ViM.fit`` and ``ViM.fit_features``).
Self = TypeVar("Self")
[docs]
class ViM(FeaturesDetector):
    """
    Implements Virtual Logit Matching (ViM) from the paper *ViM: Out-Of-Distribution with Virtual-logit Matching*.

    The detector combines the energy of the logits with a "virtual logit": the
    scaled norm of the feature residual that remains after projecting the
    (re-centered) features onto the null space of the training features.
    Larger returned scores come from a larger residual norm and/or a lower
    logit energy.

    :see Paper:
        `ArXiv <https://arxiv.org/abs/2203.10807>`__

    :see Implementation:
        `GitHub <https://github.com/haoqiwang/vim/>`__

    .. note::
        Requires PyTorch ≥ 1.9 (``torch.linalg``).
    """

    requires_fit = True

    def __init__(
        self,
        model: Optional[Callable[[torch.Tensor], torch.Tensor]],
        d: int,
        w: torch.Tensor,
        b: torch.Tensor,
    ):
        """
        :param model: neural network to use, is assumed to output features. Can be
            ``None`` when using ``fit_features(...)`` and ``predict_features(...)`` directly.
        :param d: dimensionality of the principal subspace
        :param w: weights :math:`W` of the last layer of the network
        :param b: biases :math:`b` of the last layer of the network
        """
        super().__init__()
        self.model = model
        self.n_dim = d

        # Keep detached float32 CPU copies of the classification head.
        w = w.detach().cpu().float()
        b = b.detach().cpu().float()
        self.w = w  # (C, D)
        self.b = b  # (C,)
        # New origin of the feature space: u = -W^+ b, computed with the pseudo-inverse of W.
        self.u = -(torch.linalg.pinv(w) @ b)  # (D,)

        # NOTE: despite its name, after fitting this holds the basis of the
        # *null* space, i.e. the eigenvectors belonging to the smallest
        # eigenvalues; shape (D, D - n_dim). ``None`` until fitted.
        self.principal_subspace: Optional[Tensor] = None
        self.alpha: Optional[float] = None  #: the computed :math:`\alpha` value

    def _get_logits(self, features: Tensor) -> Tensor:
        """
        Calculates logits from features by applying the stored last layer.

        :param features: features, assumed to be of shape (N, D)
        :return: logits :math:`x W^T + b`
        """
        return features @ self.w.T + self.b

    def predict(self, x: Tensor) -> Tensor:
        """
        :param x: model input, will be passed through neural network
        :return: outlier scores as computed by :meth:`predict_features`
        :raise ModelNotSetException: if no model was given
        :raise RequiresFittingException: if the detector has not been fitted
        """
        if self.model is None:
            raise ModelNotSetException

        # Check the fit state up front so that no forward pass is wasted.
        if self.principal_subspace is None or self.alpha is None:
            raise RequiresFittingException()

        with torch.no_grad():
            features = self.model(x)

        return self.predict_features(features)

    def __repr__(self):
        return f"ViM(d={self.n_dim})"

    def fit(self: Self, data_loader: DataLoader) -> Self:
        """
        Extracts features with the model and fits the detector on them,
        see :meth:`fit_features`.

        NOTE(review): any filtering of OOD samples would have to happen inside
        ``extract_features``; :meth:`fit_features` itself does not use the labels.

        :param data_loader: dataset to fit on
        :return: the fitted detector
        :raise ModelNotSetException: if no model was given
        """
        if self.model is None:
            raise ModelNotSetException

        # ``device`` and ``to`` are not defined in this class; presumably they
        # are provided by the base class -- verify against ``FeaturesDetector``.
        device = self.device
        if device is None:
            device = "cpu"
            log.warning("No device set. Will use '%s'.", device)
            self.to(device)

        features, labels = extract_features(data_loader, self.model, device)
        return self.fit_features(features, labels)

    def predict_features(self, x: Tensor) -> Tensor:
        """
        :param x: features as given by the model
        :return: outlier scores, virtual logit minus the energy of the logits
        :raise RequiresFittingException: if the detector has not been fitted
        """
        # Fail with the same exception as ``predict`` instead of an opaque
        # TypeError caused by multiplying with ``None``.
        if self.principal_subspace is None or self.alpha is None:
            raise RequiresFittingException()

        feats = x.detach().to(self.w.device).float()
        logits = self._get_logits(feats)  # (N, C)

        # Project the re-centered features onto the null space and take the L2 norm.
        residual = (feats - self.u) @ self.principal_subspace  # (N, D - n_dim)
        vlogit = residual.norm(dim=-1) * self.alpha  # (N,)

        # Logits are clipped to [-100, 100] before the energy is computed.
        energy = torch.logsumexp(logits.clamp(-100, 100), dim=-1)  # (N,)

        # Equals the negated ViM score ``-(energy - vlogit)``.
        return vlogit - energy

    def fit_features(self: Self, features: Tensor, labels: Tensor) -> Self:
        """
        Computes the null space of the given features and the scaling factor alpha.

        :param features: features, expected shape (N, D)
        :param labels: class labels. Currently not used by this implementation,
            no filtering of samples is performed here.
        :return: the fitted detector
        """
        features = features.cpu().float()

        # The null space needs at least one dimension. With ``n_dim == D`` it
        # would be empty, all virtual logits would be zero and alpha would be
        # the result of a division by zero, so this case is adjusted as well.
        if features.shape[1] <= self.n_dim:
            n = features.shape[1] // 2
            log.warning(
                f"{features.shape[1]=} is not larger than {self.n_dim=}. Will be adjusted to {n}"
            )
            self.n_dim = n

        logits = self._get_logits(features)  # (N, C)

        log.info("Computing principal space ...")
        centered = features - self.u  # (N, D) features relative to the new origin

        # Empirical covariance around the new origin (divide by n).
        cov = (centered.T @ centered) / centered.shape[0]  # (D, D)

        # torch.linalg.eigh returns the eigenvalues in ascending order with the
        # corresponding eigenvectors as columns.
        _, eig_vecs = torch.linalg.eigh(cov)  # vecs: (D, D)

        # Null space: the (D - n_dim) eigenvectors with the *smallest*
        # eigenvalues, which are the first columns for ascending order.
        k = eig_vecs.shape[1] - self.n_dim
        self.principal_subspace = eig_vecs[:, :k].contiguous()  # (D, D - n_dim)

        log.info("Computing alpha ...")
        vlogits = (centered @ self.principal_subspace).norm(dim=-1)  # (N,)
        # alpha matches the mean virtual logit to the mean maximum logit.
        self.alpha = (logits.max(dim=-1).values.mean() / vlogits.mean()).item()
        log.info(f"{self.alpha=:.4f}")

        return self