Source code for pytorch_ood.detector.gradnormkl

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: segmentation badge
.. image:: https://img.shields.io/badge/AI_Coded-yes-blue?style=flat-square
   :alt: slop-badge

..  autoclass:: pytorch_ood.detector.GradNormKL
    :members:
    :inherited-members:
    :show-inheritance:
    :exclude-members: fit
"""
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from typing import TypeVar, Callable

from ..api import Detector, ModelNotSetException

try:
    from torch.func import grad as _func_grad, vmap as _vmap, functional_call as _functional_call

    _TORCH_FUNC_AVAILABLE = True
except ImportError:
    _TORCH_FUNC_AVAILABLE = False

Self = TypeVar("Self")


[docs] class GradNormKL(Detector): """ Detector from the paper *On the Importance of Gradients for Detecting Distributional Shifts in the Wild*. For each input sample, computes the KL divergence between the softmax output and a uniform distribution (implemented via binary cross-entropy with a uniform confounding label of :math:`1/C` per class). The outlier score is the **negated** :math:`\\ell_1`-norm of the gradients of this loss w.r.t. the selected model parameters. The key insight is that the gradient w.r.t. the logits simplifies to :math:`\\text{softmax}(z) - 1/C`, which is zero when the model predicts a uniform distribution and grows as the prediction becomes more peaked. For in-distribution inputs the model is typically more confident (larger gradient norm) than for OOD inputs, so the negated norm gives higher scores to OOD samples, consistent with the convention that higher outlier scores indicate OOD data. .. note:: The paper recommends using only the gradients of the final classification head (last FC layer) for computational efficiency. You can achieve this by setting ``param_filter`` and disabling gradient computation for the backbone via ``model.requires_grad_(False); model.fc.requires_grad_(True)``. .. note:: On PyTorch ≥ 2.0, per-sample gradients are computed with ``torch.func.vmap`` + ``torch.func.grad`` in a single batched forward+backward pass. On PyTorch 1.x the original sequential loop over individual samples is used as a fallback. :see Paper: `NeurIPS <https://arxiv.org/abs/2110.00218>`__ """ def __init__(self, model: torch.nn.Module, param_filter: Callable[[str], bool] = None): """ :param model: A pre-trained classification model. :param param_filter: Function indicating whether a named parameter should be included in the scoring. If ``None``, all parameters are used. """ if model is None: raise ModelNotSetException("Model must be provided.") def default_filter(x): return True self.param_filter = param_filter or default_filter self.model = model def fit(self, data_loader: DataLoader, **kwargs) -> Self: return self
[docs] def predict(self, x: Tensor) -> Tensor: """ Compute outlier scores for an input batch. Uses the device of the model parameters for all computations. On PyTorch ≥ 2.0, per-sample gradients are batched via ``torch.func``; on older versions a sequential loop is used. :param x: input tensor, will be passed through the network :return: vector of outlier scores (higher = more likely OOD) """ if self.model is None: raise ModelNotSetException() device = next(self.model.parameters()).device x = x.to(device) if _TORCH_FUNC_AVAILABLE: return self._predict_batched(x) return self._predict_sequential(x)
def _predict_batched(self, x: Tensor) -> Tensor: """Vectorized per-sample gradients via torch.func (PyTorch ≥ 2.0).""" params = dict(self.model.named_parameters()) buffers = dict(self.model.named_buffers()) model = self.model param_filter = self.param_filter def loss_for_single(params, x_single): logits = _functional_call(model, (params, buffers), (x_single.unsqueeze(0),)) C = logits.shape[1] y_uniform = torch.ones_like(logits) / C return F.binary_cross_entropy(logits.softmax(dim=1), y_uniform, reduction="sum") with torch.enable_grad(): per_sample_grads = _vmap(_func_grad(loss_for_single), in_dims=(None, 0))(params, x) total_norms = x.new_zeros(x.shape[0]) for name, g in per_sample_grads.items(): if param_filter(name): total_norms = total_norms + g.abs().sum(dim=tuple(range(1, g.ndim))) return -total_norms def _predict_sequential(self, x: Tensor) -> Tensor: """Per-sample gradients via serial backward passes (PyTorch < 2.0 fallback).""" device = x.device scores = [] for xi in x: with torch.enable_grad(): self.model.zero_grad() logits = self.model(xi.unsqueeze(0)) C = logits.shape[1] y_uniform = torch.ones_like(logits) / C loss = F.binary_cross_entropy(logits.softmax(dim=1), y_uniform, reduction="sum") loss.backward() total_norm = torch.tensor(0.0, device=device) for name, p in self.model.named_parameters(): if self.param_filter(name) and p.grad is not None: total_norm = total_norm + p.grad.detach().abs().sum() scores.append(-total_norm) return torch.stack(scores)