# Source code for pytorch_ood.benchmark.base

import re
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple, Union, overload

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_ood.api import Detector, FeaturesDetector, LogitsDetector
from pytorch_ood.detector.mahalanobis import Mahalanobis
from pytorch_ood.utils import OODMetrics, TensorBuffer

# Version stamp stored in the "cache_version" metadata field of every cached
# payload. Disk caches whose stored value differs from this constant are
# rejected when loaded, so the representation is recomputed.
_CACHE_VERSION = 1


def _sanitize_cache_token(value: str) -> str:
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", value).strip("_") or "cache"


class Benchmark(ABC):
    """
    Base class for Benchmarks.

    Subclasses implement :meth:`train_set` and :meth:`test_sets`.
    :meth:`evaluate` additionally reads ``self.ood_names`` and pairs its
    entries with the datasets returned by :meth:`test_sets`; this base class
    does not define that attribute itself.
    """

    @abstractmethod
    def train_set(self) -> Dataset:
        """
        Training dataset
        """

    @abstractmethod
    def test_sets(self, known=True, unknown=True) -> List[Dataset]:
        """
        List of the different test datasets.
        If known and unknown are true, each dataset contains ID and OOD data.

        :param known: include ID
        :param unknown: include OOD
        """
        pass

    def _ensure_cache_state(self):
        """
        Lazily create the per-instance cache attributes.

        The attributes are created on first use rather than in ``__init__``,
        so the helpers work regardless of how a subclass initializes itself.
        """
        if not hasattr(self, "_representation_cache"):
            self._representation_cache = {}
        if not hasattr(self, "_cache_warnings_shown"):
            self._cache_warnings_shown = set()

    def _warn_once(self, key: str, message: str) -> None:
        """
        Emit a ``UserWarning`` at most once per ``key`` for this instance.

        :param key: identifier used to remember that the warning was shown
        :param message: warning text
        """
        self._ensure_cache_state()
        if key in self._cache_warnings_shown:
            return
        # stacklevel=3 attributes the warning to the frame two levels above
        # this helper, i.e. the caller of the method that invoked it.
        warnings.warn(message, UserWarning, stacklevel=3)
        self._cache_warnings_shown.add(key)

    @staticmethod
    def _normalize_detectors(
        detector: Union[Detector, Sequence[Detector]],
    ) -> Tuple[List[Detector], bool]:
        """
        Normalize the ``detector`` argument of :meth:`evaluate` to a list.

        :param detector: a single detector or a sequence of detectors
        :return: tuple ``(detectors, many)`` where ``many`` is ``True`` iff a
            sequence was passed
        :raises ValueError: if an empty sequence is given
        :raises TypeError: if the argument is neither a detector nor a
            sequence of detectors
        """
        if isinstance(detector, Detector):
            return [detector], False

        if isinstance(detector, Sequence):
            detectors = list(detector)
            if not detectors:
                raise ValueError("At least one detector must be provided")
            if not all(isinstance(item, Detector) for item in detectors):
                raise TypeError("All elements must be Detector instances")
            return detectors, True

        raise TypeError("detector must be a Detector or a sequence of Detector instances")

    @staticmethod
    def _get_logits_producer(detector: LogitsDetector):
        """
        Return the detector's ``model`` attribute, or ``None`` if it has none.
        """
        return getattr(detector, "model", None)

    @staticmethod
    def _get_features_producer(detector: FeaturesDetector):
        """
        Return the first non-``None`` attribute among ``model``, ``encoder``
        and ``backbone`` (checked in that order), or ``None`` if there is none.
        """
        for attr in ("model", "encoder", "backbone"):
            producer = getattr(detector, attr, None)
            if producer is not None:
                return producer
        return None

    @staticmethod
    def _producer_token(producer) -> str:
        """
        Return ``"<module>.<qualname>"`` of the producer's class, or ``"none"``.

        The token only identifies the class, not the instance or its state.
        """
        if producer is None:
            return "none"
        return f"{producer.__class__.__module__}.{producer.__class__.__qualname__}"

    def _memory_cache_key(
        self,
        split: str,
        dataset_name: str,
        representation: str,
        producer,
        cache_key: Optional[str],
    ) -> Tuple[str, str, str, int, Optional[str]]:
        """
        Build the key used for the in-memory caches.

        :return: tuple ``(split, dataset_name, representation, id(producer), cache_key)``
        """
        # NOTE(review): id() values can be reused once an object has been
        # garbage collected, so entries kept on the benchmark instance may
        # collide with a different producer created later - confirm whether
        # callers rely on long-lived benchmark instances with cache=True.
        return split, dataset_name, representation, id(producer), cache_key

    def _cache_file_path(
        self,
        cache_dir: str,
        split: str,
        dataset_name: str,
        representation: str,
        producer,
        cache_key: str,
    ) -> Path:
        """
        Build the path of the disk cache file for the given combination.

        Every component is sanitized before it is embedded in the file name.

        :return: ``<cache_dir>/<cache_key>_<split>_<dataset>_<representation>_<producer>.pt``
        """
        producer_token = _sanitize_cache_token(self._producer_token(producer))
        dataset_token = _sanitize_cache_token(dataset_name)
        split_token = _sanitize_cache_token(split)
        representation_token = _sanitize_cache_token(representation)
        cache_key_token = _sanitize_cache_token(cache_key)
        filename = (
            f"{cache_key_token}_{split_token}_{dataset_token}_"
            f"{representation_token}_{producer_token}.pt"
        )
        return Path(cache_dir) / filename

    def _load_disk_cache(
        self,
        cache_dir: Optional[str],
        split: str,
        dataset_name: str,
        representation: str,
        producer,
        cache_key: Optional[str],
    ) -> Optional[Dict]:
        """
        Load a cached payload from disk if a usable one exists.

        :return: the payload, or ``None`` if disk caching is disabled
            (``cache_dir`` or ``cache_key`` is ``None``), the file does not
            exist, the stored object does not have the expected structure, or
            its metadata does not match the request
        """
        if cache_dir is None or cache_key is None:
            return None

        path = self._cache_file_path(
            cache_dir=cache_dir,
            split=split,
            dataset_name=dataset_name,
            representation=representation,
            producer=producer,
            cache_key=cache_key,
        )
        if not path.exists():
            return None

        # NOTE: torch.load deserializes via pickle. Only point cache_dir at
        # locations whose contents are trusted.
        payload = torch.load(path, map_location="cpu")

        # Treat files with an unexpected structure as a cache miss instead of
        # failing here or later when the tensors are read.
        if not isinstance(payload, dict):
            return None
        metadata = payload.get("metadata")
        data = payload.get("data")
        if not isinstance(metadata, dict) or not isinstance(data, dict):
            return None
        if representation not in data or "label" not in data:
            return None

        expected = {
            "cache_version": _CACHE_VERSION,
            "split": split,
            "dataset_name": dataset_name,
            "representation": representation,
            "cache_key": cache_key,
            "producer_token": self._producer_token(producer),
        }
        if any(metadata.get(key) != value for key, value in expected.items()):
            return None

        return payload

    def _save_disk_cache(
        self,
        payload: Dict,
        cache_dir: Optional[str],
        split: str,
        dataset_name: str,
        representation: str,
        producer,
        cache_key: Optional[str],
    ) -> None:
        """
        Write ``payload`` to its disk cache file.

        Does nothing if ``cache_dir`` or ``cache_key`` is ``None``. Missing
        parent directories are created.
        """
        if cache_dir is None or cache_key is None:
            return

        path = self._cache_file_path(
            cache_dir=cache_dir,
            split=split,
            dataset_name=dataset_name,
            representation=representation,
            producer=producer,
            cache_key=cache_key,
        )
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(payload, path)

    @staticmethod
    def _extract_representation(
        data_loader: DataLoader,
        producer,
        device: str,
        representation: str,
    ) -> Dict:
        """
        Run ``producer`` over all batches and collect its flattened outputs.

        :param data_loader: loader yielding ``(x, y)`` batches
        :param producer: callable applied to each input batch
        :param device: device the inputs are moved to before the call
        :param representation: key under which the outputs are stored
        :return: dict with the collected outputs under ``representation`` and
            the collected labels under ``"label"``
        """
        buffer = TensorBuffer(device="cpu")

        with torch.no_grad():
            for x, y in data_loader:
                z = producer(x.to(device))
                # reshape (unlike view) also accepts non-contiguous outputs
                # and yields the same result whenever view would succeed.
                z = z.reshape(z.shape[0], -1)
                buffer.append(representation, z)
                buffer.append("label", y)

        return {
            representation: buffer.get(representation),
            "label": buffer.get("label"),
        }

    def _get_representation_cache(
        self,
        split: str,
        dataset_name: str,
        representation: str,
        producer,
        data_loader: DataLoader,
        device: str,
        persist_memory: bool,
        cache_dir: Optional[str],
        cache_key: Optional[str],
        local_cache: Dict,
    ) -> Dict:
        """
        Return the cached payload for a representation, computing it if needed.

        Lookup order: ``local_cache``, then the cache kept on this instance
        (only if ``persist_memory``), then the disk cache, and finally a fresh
        extraction via :meth:`_extract_representation`. A freshly extracted
        payload is written to the disk cache (if enabled). Whatever payload is
        found or computed is stored in ``local_cache`` and, if
        ``persist_memory``, on this instance.

        :return: dict with the keys ``"metadata"`` and ``"data"``
        """
        self._ensure_cache_state()
        memory_key = self._memory_cache_key(
            split=split,
            dataset_name=dataset_name,
            representation=representation,
            producer=producer,
            cache_key=cache_key,
        )

        if memory_key in local_cache:
            return local_cache[memory_key]

        if persist_memory and memory_key in self._representation_cache:
            payload = self._representation_cache[memory_key]
            local_cache[memory_key] = payload
            return payload

        payload = self._load_disk_cache(
            cache_dir=cache_dir,
            split=split,
            dataset_name=dataset_name,
            representation=representation,
            producer=producer,
            cache_key=cache_key,
        )

        if payload is None:
            data = self._extract_representation(
                data_loader=data_loader,
                producer=producer,
                device=device,
                representation=representation,
            )
            payload = {
                "metadata": {
                    "cache_version": _CACHE_VERSION,
                    "split": split,
                    "dataset_name": dataset_name,
                    "representation": representation,
                    "cache_key": cache_key,
                    "producer_token": self._producer_token(producer),
                    "num_samples": int(data["label"].shape[0]),
                },
                "data": data,
            }
            self._save_disk_cache(
                payload=payload,
                cache_dir=cache_dir,
                split=split,
                dataset_name=dataset_name,
                representation=representation,
                producer=producer,
                cache_key=cache_key,
            )

        if persist_memory:
            self._representation_cache[memory_key] = payload
        local_cache[memory_key] = payload
        return payload

    @staticmethod
    def _supports_cached_logits(detector: Detector) -> bool:
        """
        Return ``True`` iff the detector is a ``LogitsDetector`` with a
        non-``None`` ``model`` attribute.
        """
        return (
            isinstance(detector, LogitsDetector)
            and getattr(detector, "model", None) is not None
        )

    @staticmethod
    def _supports_cached_features(detector: Detector) -> bool:
        """
        Return ``True`` iff the detector is a ``FeaturesDetector`` that can be
        evaluated on cached features.

        ``Mahalanobis`` detectors with ``eps > 0`` are excluded.
        """
        if not isinstance(detector, FeaturesDetector):
            return False
        # NOTE(review): presumably a positive eps requires access to the raw
        # inputs, which cached features cannot provide - confirm against the
        # Mahalanobis implementation.
        if isinstance(detector, Mahalanobis) and detector.eps > 0:
            return False
        return True

    @staticmethod
    def _evaluate_raw(detector: Detector, data_loader: DataLoader, device: str) -> Dict:
        """
        Evaluate a detector by calling it on every input batch of the loader.

        :return: result of ``OODMetrics.compute()``
        """
        metrics = OODMetrics()

        for x, y in data_loader:
            scores = detector(x.to(device))
            # Labels only need to live where the scores are.
            metrics.update(scores, y.to(scores.device))

        return metrics.compute()

    @staticmethod
    def _evaluate_logits(detector: LogitsDetector, payload: Dict) -> Dict:
        """
        Evaluate a detector on cached logits via ``detector.predict_logits``.

        :param payload: cache payload whose ``"data"`` entry holds ``"logits"``
            and ``"label"``
        :return: result of ``OODMetrics.compute()``
        """
        metrics = OODMetrics()
        logits = payload["data"]["logits"]
        labels = payload["data"]["label"]
        scores = detector.predict_logits(logits)
        metrics.update(scores, labels.to(scores.device))
        return metrics.compute()

    @staticmethod
    def _evaluate_features(detector: FeaturesDetector, payload: Dict) -> Dict:
        """
        Evaluate a detector on cached features via ``detector.predict_features``.

        :param payload: cache payload whose ``"data"`` entry holds
            ``"features"`` and ``"label"``
        :return: result of ``OODMetrics.compute()``
        """
        metrics = OODMetrics()
        features = payload["data"]["features"]
        labels = payload["data"]["label"]
        scores = detector.predict_features(features)
        metrics.update(scores, labels.to(scores.device))
        return metrics.compute()

    def _evaluate_single_detector(
        self,
        detector: Detector,
        dataset_name: str,
        data_loader: DataLoader,
        device: str,
        persist_memory: bool,
        cache_dir: Optional[str],
        cache_key: Optional[str],
        local_cache: Dict,
    ) -> Dict:
        """
        Evaluate one detector on one dataset, using cached representations
        where the detector supports it.

        Cached logits are tried first, then cached features (only if a
        features producer can be found on the detector). Otherwise the
        detector is called directly on the inputs.

        :return: metrics dict for this detector and dataset
        """
        detector = detector.to(device)

        if self._supports_cached_logits(detector):
            producer = self._get_logits_producer(detector)
            payload = self._get_representation_cache(
                split="eval",
                dataset_name=dataset_name,
                representation="logits",
                producer=producer,
                data_loader=data_loader,
                device=device,
                persist_memory=persist_memory,
                cache_dir=cache_dir,
                cache_key=cache_key,
                local_cache=local_cache,
            )
            return self._evaluate_logits(detector, payload)

        if self._supports_cached_features(detector):
            producer = self._get_features_producer(detector)
            if producer is not None:
                payload = self._get_representation_cache(
                    split="eval",
                    dataset_name=dataset_name,
                    representation="features",
                    producer=producer,
                    data_loader=data_loader,
                    device=device,
                    persist_memory=persist_memory,
                    cache_dir=cache_dir,
                    cache_key=cache_key,
                    local_cache=local_cache,
                )
                return self._evaluate_features(detector, payload)

        return self._evaluate_raw(detector, data_loader, device)

    @overload
    def evaluate(
        self,
        detector: Detector,
        loader_kwargs: Optional[Dict] = None,
        device: str = "cpu",
        cache: bool = False,
        cache_dir: Optional[str] = None,
        cache_key: Optional[str] = None,
    ) -> List[Dict]: ...

    @overload
    def evaluate(
        self,
        detector: Sequence[Detector],
        loader_kwargs: Optional[Dict] = None,
        device: str = "cpu",
        cache: bool = False,
        cache_dir: Optional[str] = None,
        cache_key: Optional[str] = None,
    ) -> List[Dict]: ...

    def evaluate(
        self,
        detector: Union[Detector, Sequence[Detector]],
        loader_kwargs: Optional[Dict] = None,
        device: str = "cpu",
        cache: bool = False,
        cache_dir: Optional[str] = None,
        cache_key: Optional[str] = None,
    ) -> List[Dict]:
        """
        Evaluate one detector or a list of detectors on all benchmark datasets.

        When several logits detectors or pooled-feature detectors are evaluated
        together, this method can reuse cached intermediate representations
        instead of recomputing model outputs for every detector.

        If ``cache=True``, those representations are also kept on the benchmark
        instance and reused across later ``evaluate(...)`` calls. If
        ``cache_dir`` is given, cached tensors are additionally persisted to
        disk. Disk-backed cache reuse is keyed only by user-provided
        ``cache_key`` and lightweight metadata, so cache correctness is the
        caller's responsibility. If ``cache_dir`` is given without a
        ``cache_key``, a warning is issued and only in-memory caching is used.

        Dataset names are taken from ``self.ood_names`` and paired, in order,
        with the datasets returned by :meth:`test_sets`.

        :param detector: detector instance or a sequence of detectors
        :param loader_kwargs: keyword arguments forwarded to the data loader
        :param device: device to move inputs and detectors to
        :param cache: keep cached representations on the benchmark instance
        :param cache_dir: optional directory for file-backed caches
        :param cache_key: user-supplied cache key used for disk cache reuse
        :return: benchmark results, one dict per dataset and detector, each
            with a ``Dataset`` field. For multiple detectors, each result also
            includes a ``Detector`` field with the detector class name.
        :raises ValueError: if an empty sequence of detectors is given
        :raises TypeError: if ``detector`` is neither a detector nor a
            sequence of detectors
        """
        detectors, many = self._normalize_detectors(detector)

        if loader_kwargs is None:
            loader_kwargs = {}

        # Computed before cache_dir may be reset below: requesting a cache
        # directory enables in-memory persistence even without a cache_key.
        persist_memory = cache or cache_dir is not None

        if cache_dir is not None and cache_key is None:
            self._warn_once(
                "cache_dir_without_key",
                "File-backed benchmark caching was requested without a cache_key. "
                "Disk cache reuse is disabled; only in-memory caching will be used.",
            )
            cache_dir = None
        elif cache_dir is not None:
            self._warn_once(
                f"cache_key_responsibility:{cache_key}",
                "Benchmark cache reuse is keyed by the user-supplied cache_key and "
                "lightweight metadata only. Make sure your cache_key changes when "
                "the model, weights, or transforms change.",
            )

        metrics = []

        for dataset_name, dataset in zip(self.ood_names, self.test_sets()):
            loader = DataLoader(dataset=dataset, **loader_kwargs)
            # Shared by all detectors for this dataset within this call.
            local_cache = {}

            for current_detector in detectors:
                result = self._evaluate_single_detector(
                    detector=current_detector,
                    dataset_name=dataset_name,
                    data_loader=loader,
                    device=device,
                    persist_memory=persist_memory,
                    cache_dir=cache_dir,
                    cache_key=cache_key,
                    local_cache=local_cache,
                )
                result.update({"Dataset": dataset_name})
                if many:
                    result.update({"Detector": type(current_detector).__name__})
                metrics.append(result)

        return metrics