CIFAR 10 Cached Baseline

Benchmark code for CIFAR-10 that reuses cached logits and pooled features when fitting detectors that expose the standard representation APIs (``fit_logits`` / ``fit_features``). At evaluation time, cached logits are used for detectors that score logits directly, while feature detectors that operate directly on pooled features, without extra input preprocessing, use the cached feature representation. All other detectors still go through their regular predict() interface, so their full model-time behavior is preserved.

 15 from copy import deepcopy
 16 import gc
 17 import pandas as pd  # additional dependency, used here for convenience
 18 import torch
 19 from torch import nn
 20 from torch.utils.data import DataLoader
 21 from torchvision.datasets import CIFAR10, CIFAR100, FashionMNIST, MNIST
 22 from tqdm.auto import tqdm  # additional dependency, used here for convenience
 23
 24 from pytorch_ood.dataset.img import (
 25     LSUNCrop,
 26     LSUNResize,
 27     Places365,
 28     Textures,
 29     TinyImageNetCrop,
 30     TinyImageNetResize,
 31 )
 32 from pytorch_ood.api import FeaturesDetector, LogitsDetector
 33 from pytorch_ood.detector import (
 34     ASH,
 35     DICE,
 36     EnergyBased,
 37     Entropy,
 38     GEN,
 39     GMM,
 40     GradNorm,
 41     GradNormKL,
 42     Gram,
 43     KLMatching,
 44     KNN,
 45     Mahalanobis,
 46     MaxLogit,
 47     MaxSoftmax,
 48     MultiMahalanobis,
 49     NACUE,
 50     ODIN,
 51     RMD,
 52     RankFeat,
 53     SHE,
 54     ViM,
 55     fDBD,
 56 )
 57 from pytorch_ood.model import WideResNet
 58 from pytorch_ood.utils import (
 59     OODMetrics,
 60     TensorBuffer,
 61     ToUnknown,
 62     extract_features,
 63     fix_random_seed,
 64 )
 65
 66 device = "cuda:0"
 67
 68 fix_random_seed(123)
 69
 70
 71 def cache_representations(
 72     data_loader: DataLoader, model: nn.Module, feature_model: nn.Module, device: str
 73 ):
 74     """
 75     Cache logits, pooled features, and labels for all samples in a loader.
 76     Unlike ``extract_features(...)``, this keeps OOD samples and labels.
 77     """
 78     print(f"Extracting cached logits/features from {len(data_loader.dataset)} samples")
 79     buffer = TensorBuffer()
 80
 81     with torch.no_grad():
 82         for x, y in data_loader:
 83             x = x.to(device)
 84             buffer.append("logits", model(x).view(x.shape[0], -1))
 85             buffer.append("features", feature_model(x).view(x.shape[0], -1))
 86             buffer.append("label", y)
 87
 88     return {
 89         "logits": buffer.get("logits"),
 90         "features": buffer.get("features"),
 91         "labels": buffer.get("label"),
 92     }
 93
 94
 95 def fit_detector(
 96     detector, train_loader: DataLoader, train_cache: dict, train_labels: torch.Tensor, device: str
 97 ):
 98     """
 99     Fit using cached representations for detectors with a standard semantic
100     interface. Otherwise fall back to ``fit(...)``.
101     """
102     if detector.requires_fit and isinstance(detector, LogitsDetector):
103         detector.fit_logits(train_cache["logits"], train_labels)
104         return
105
106     if detector.requires_fit and isinstance(detector, FeaturesDetector):
107         detector.fit_features(train_cache["features"], train_labels)
108         return
109
110     if detector.requires_fit:
111         detector.to(device)
112         detector.fit(train_loader)
113
114
def evaluate_detector(detector, data_loader: DataLoader, eval_cache: dict, device: str) -> dict:
    """
    Compute OOD metrics for ``detector``, scoring from the cache where possible.

    * A ``LogitsDetector`` is scored on ``eval_cache["logits"]``.
    * A ``FeaturesDetector`` is scored on ``eval_cache["features"]``, except
      for ``Mahalanobis`` with ``eps > 0``, whose ``predict(x)`` adds extra
      input preprocessing and therefore needs the raw inputs.
    * Everything else is scored batch by batch with ``predict(x)`` on the
      loader.

    :param detector: detector to evaluate
    :param data_loader: loader yielding ``(x, y)`` batches, used by the fallback
    :param eval_cache: dict with cached ``"logits"``, ``"features"``, ``"labels"``
    :param device: device the detector (and fallback batches) are moved to
    :return: result of ``OODMetrics.compute()``
    """
    metrics = OODMetrics()
    detector.to(device)

    if isinstance(detector, LogitsDetector):
        scores = detector.predict_logits(eval_cache["logits"])
    elif isinstance(detector, FeaturesDetector) and not (
        isinstance(detector, Mahalanobis) and detector.eps > 0
    ):
        scores = detector.predict_features(eval_cache["features"])
    else:
        # The detector needs the inputs themselves: stream the loader.
        for batch, targets in tqdm(data_loader, leave=False):
            metrics.update(detector(batch.to(device)), targets.to(device))
        return metrics.compute()

    # Cached path: labels are moved to wherever the scores ended up.
    metrics.update(scores, eval_cache["labels"].to(scores.device))
    return metrics.compute()

Setup preprocessing

# Input transform and normalization statistics that WideResNet provides for
# the "cifar10-pt" key; the same key selects the model weights further below.
trans = WideResNet.transform_for("cifar10-pt")
norm_std = WideResNet.norm_std_for("cifar10-pt")

Setup datasets

# In-distribution test data: the CIFAR-10 test split.
dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)

# NOTE(review): the torchvision datasets (CIFAR100, MNIST, FashionMNIST) are
# built without ``train=...`` and therefore use their default split
# (``train=True``) - confirm this is intended.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]

# One loader per OOD dataset, each over the in-distribution test data
# concatenated with that OOD dataset, keyed by the dataset class name.
datasets = {}
for ood_dataset in ood_datasets:
    dataset_out_test = ood_dataset(
        root="data",
        transform=trans,
        target_transform=ToUnknown(),
        download=True,
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test,
        batch_size=128,
        num_workers=12,
    )
    datasets[ood_dataset.__name__] = test_loader

Stage 1: Create model

print("STAGE 1: Creating a Model")
# Build the network with the "cifar10-pt" weights, then switch it to eval
# mode and move it to the target device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)

Stage 2: Create detectors

print("STAGE 2: Creating OOD Detectors")


def _copy_with_trainable_fc(net):
    """Deep-copy ``net`` with gradients disabled everywhere except ``net.fc``."""
    clone = deepcopy(net)
    clone.requires_grad_(False)
    clone.fc.requires_grad_(True)
    return clone


# Insertion order is the order in which the detectors are fitted and evaluated.
detectors = {
    "KNN": KNN(model.features),
    "GMM": GMM(model.features),
    "fDBD": fDBD(encoder=model.features, head=model.fc),
    "ASH": ASH(backbone=model.features_before_pool, head=model.forward_from_before_pool),
    "RankFeat": RankFeat(
        backbone=model.features_before_pool, head=model.forward_from_before_pool
    ),
}

# The gradient-based detectors each get their own copy of the model in which
# only the parameters of ``fc`` require gradients.
model_gn = _copy_with_trainable_fc(model)
detectors["GradNorm"] = GradNorm(model_gn, param_filter=lambda name: name.startswith("fc"))

model_gnkl = _copy_with_trainable_fc(model)
detectors["GradNormKL"] = GradNormKL(model_gnkl, param_filter=lambda name: name.startswith("fc"))

detectors.update(
    {
        "Entropy": Entropy(model),
        "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
        "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
        "Mahalanobis": Mahalanobis(model.features),
        "KLMatching": KLMatching(model),
        "SHE": SHE(model.features, model.fc),
        "MSP": MaxSoftmax(model),
        "EnergyBased": EnergyBased(model),
        "GEN": GEN(model),
        "MaxLogit": MaxLogit(model),
        "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
        "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
        "RMD": RMD(model.features),
    }
)

# Multi-layer detectors: each one receives its own list of network stages.
detectors["MultiMahalanobis"] = MultiMahalanobis(
    [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]
)
detectors["Gram"] = Gram(
    num_classes=10,
    head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
    feature_layers=[
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ],
)

detectors["NAC-UE"] = NACUE(
    model=model,
    layers=[model.block2, model.block3, model.bn1],
    m_bins=[200, 200, 200],
    alpha=[150.0, 200.0, 250.0],
    o_star=[25, 50, 100],
    device=device,
)

Stage 3: Fit detectors

print(f"STAGE 3: Fitting {len(detectors)} detectors")
loader_in_train = DataLoader(
    CIFAR10(root="data", train=True, transform=trans),
    batch_size=128,
    num_workers=12,
)

# Two passes over the training data: one through the full model (logits) and
# one through ``model.features`` (pooled features).
print("Extracting training logits")
train_logits, train_labels_logits = extract_features(loader_in_train, model, device)
print("Extracting training pooled features")
train_features, train_labels_features = extract_features(loader_in_train, model.features, device)

# Sanity check: both passes must return the same labels in the same order,
# since a single label tensor is shared by both cached representations.
assert torch.equal(train_labels_logits, train_labels_features)
train_labels = train_labels_logits
train_cache = dict(logits=train_logits, features=train_features)

for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    fit_detector(detector, loader_in_train, train_cache, train_labels, device)

Stage 4: Evaluate detectors

print(f"STAGE 4: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []

for dataset_name, loader in datasets.items():
    # Cache logits/features once per dataset and reuse them for every detector.
    print(f"> Caching representations for {dataset_name}")
    eval_cache = cache_representations(loader, model, model.features, device)

    for detector_name, detector in detectors.items():
        print(f"--> {detector_name}")
        scores = evaluate_detector(detector, loader, eval_cache, device)
        result = {"Detector": detector_name, "Dataset": dataset_name, **scores}
        results.append(result)

    # Drop this dataset's cache before the next one is built.
    del eval_cache
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Average every metric over all datasets per detector, reported in percent.
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean() * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery