# Source code for pytorch_ood.detector.gradnorm

"""

.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
   :alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-no-red?style=flat-square
   :alt: segmentation badge

..  autoclass:: pytorch_ood.detector.GradNorm
    :members:
    :inherited-members:
    :show-inheritance:
    :exclude-members: fit
"""
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from typing import TypeVar, Callable

from ..api import Detector, ModelNotSetException

# ``torch.func`` ships with PyTorch >= 2.0 and powers the vectorized per-sample-gradient
# path of ``GradNorm``; on older releases the import fails and a sequential loop is used.
try:
    from torch.func import functional_call as _functional_call
    from torch.func import grad as _func_grad
    from torch.func import vmap as _vmap
except ImportError:
    _TORCH_FUNC_AVAILABLE = False
else:
    _TORCH_FUNC_AVAILABLE = True

# Marker type for methods that return the instance itself (e.g. ``fit``).
Self = TypeVar("Self")


class GradNorm(Detector):
    """
    Detector from the paper *Gradients as a Measure of Uncertainty in Neural Networks*.

    For each input sample, computes the binary cross-entropy loss between the softmax of the
    logits and a "confounding label", which is a vector of all ones. Then, for each set of
    parameters in the model (as given by ``model.named_parameters()``), computes the squared
    :math:`\\ell_2`-norm of the gradients of the loss w.r.t. that parameter. The outlier score
    is the sum of these squared norms.

    The idea is that higher gradient norms indicate that the model would require large
    parameter updates to accommodate the input, i.e., for such data, it is less familiar or
    more uncertain, and hence more likely to be OOD.

    .. note:: OpenOOD uses only the gradients of the final classification head, which makes
        this computationally cheaper. You can achieve something similar by setting
        ``param_filter``. Parameters with ``requires_grad=False`` never contribute to the
        score. On the batched code path, gradients are only computed for parameters that pass
        ``param_filter``; on the sequential fallback, autograd still computes gradients for
        every parameter that requires them, so explicitly freeze parameters you do not need
        to save compute there.

        For an example, see :doc:`here <auto_examples/detectors/gradnorm>`

    .. note:: On PyTorch ≥ 2.0, per-sample gradients are computed with ``torch.func.vmap`` +
        ``torch.func.grad`` in a single batched forward+backward pass. On PyTorch 1.x the
        original sequential loop over individual samples is used as a fallback.

    .. warning:: The paper's actual experiments (Section 4) concatenate the per-layer squared
        L2 norms into a feature vector and then **train a 2-layer FC binary classifier** on
        labeled ID and OOD gradient representations. The current implementation is a
        significant simplification: it sums all norms into a single scalar and uses it as a
        direct outlier score without any training. This simplification requires no OOD data
        but tends to perform poorly (AUROC ≈ 0.5) when ID and OOD datasets are of similar
        complexity, because the scalar sum loses the per-layer discriminative structure the
        classifier exploits.

        For an unsupervised gradient-based alternative see
        :class:`~pytorch_ood.detector.GradNormKL`.

    :see Paper: `ICIP <https://arxiv.org/abs/2008.08030v2>`__
    """

    def __init__(self, model: torch.nn.Module, param_filter: Callable[[str], bool] = None):
        """
        :param model: A pre-trained classification model
        :param param_filter: Function which indicates whether a named parameter should be
            included in the scoring. If none is given, all parameters will be used.
        :raises ModelNotSetException: if ``model`` is ``None``
        """
        if model is None:
            raise ModelNotSetException("Model must be provided.")

        def default_filter(x):
            return True

        self.param_filter = param_filter or default_filter
        self.model = model

    def fit(self, data_loader: DataLoader, **kwargs) -> Self:
        """
        Not required for this detector.

        :return: ``self``, unchanged
        """
        return self

    def predict(self, x: Tensor) -> Tensor:
        """
        Compute outlier scores from input batch.

        We will use the device of the model parameters for computations. On PyTorch ≥ 2.0,
        per-sample gradients are batched via ``torch.func``; on older versions a sequential
        loop is used.

        :param x: input, will be passed through network
        :return: vector of outlier scores, one per sample in ``x``
        :raises ModelNotSetException: if no model is set
        """
        if self.model is None:
            raise ModelNotSetException()

        device = next(self.model.parameters()).device
        x = x.to(device)

        # An empty batch has nothing to score; the sequential path would otherwise fail in
        # ``torch.stack([])``.
        if x.shape[0] == 0:
            return torch.zeros(0, device=device)

        if _TORCH_FUNC_AVAILABLE:
            return self._predict_batched(x)
        return self._predict_sequential(x)

    def _predict_batched(self, x: Tensor) -> Tensor:
        """
        Vectorized per-sample gradients via torch.func (PyTorch ≥ 2.0).

        :param x: non-empty input batch, already on the model's device
        :return: per-sample sum of squared gradient norms over the selected parameters
        """
        model = self.model
        param_filter = self.param_filter

        # ``torch.func.grad`` differentiates w.r.t. every tensor it is handed, regardless of
        # ``requires_grad``. Therefore only hand it the parameters we actually score (selected
        # by the filter and not frozen). This mirrors the sequential path, where frozen
        # parameters have no ``.grad``, and avoids computing unused gradients. Everything
        # else is passed to the model as a constant. Detaching keeps the outer autograd graph
        # out of the functional transform.
        scored = {}
        constant = {}
        for name, p in model.named_parameters():
            if p.requires_grad and param_filter(name):
                scored[name] = p.detach()
            else:
                constant[name] = p.detach()
        buffers = dict(model.named_buffers())

        total_norms = x.new_zeros(x.shape[0])
        if not scored:
            # No parameter contributes, so every score is zero; skip the forward passes.
            return total_norms

        def loss_for_single(scored_params, x_single):
            logits = _functional_call(
                model, (scored_params, constant, buffers), (x_single.unsqueeze(0),)
            )
            y_conf = torch.ones_like(logits)
            return F.binary_cross_entropy(logits.softmax(dim=1), y_conf, reduction="sum")

        with torch.enable_grad():
            per_sample_grads = _vmap(_func_grad(loss_for_single), in_dims=(None, 0))(scored, x)

        for g in per_sample_grads.values():
            # vmap prepends the batch dimension; reduce over all remaining dimensions.
            total_norms = total_norms + (g**2).sum(dim=tuple(range(1, g.ndim)))
        return total_norms

    def _predict_sequential(self, x: Tensor) -> Tensor:
        """
        Per-sample gradients via serial backward passes (PyTorch < 2.0 fallback).

        :param x: non-empty input batch, already on the model's device
        :return: per-sample sum of squared gradient norms over the selected parameters
        """
        device = x.device
        scores = []
        for xi in x:
            with torch.enable_grad():
                self.model.zero_grad()
                logits = self.model(xi.unsqueeze(0))
                y_conf = torch.ones_like(logits, device=device)
                loss = F.binary_cross_entropy(logits.softmax(dim=1), y_conf, reduction="sum")
                loss.backward()

            total_norm = torch.tensor(0.0, device=device)
            for name, p in self.model.named_parameters():
                # Parameters without a gradient (frozen or unused in forward) contribute nothing.
                if self.param_filter(name) and p.grad is not None:
                    total_norm = total_norm + torch.sum(p.grad.detach() ** 2)
            scores.append(total_norm)

        # Drop the last sample's gradients so they neither hold memory nor leak into a later
        # optimizer step. This is done manually because ``zero_grad(set_to_none=...)`` is not
        # available on every 1.x release.
        for p in self.model.parameters():
            p.grad = None

        return torch.stack(scores)