Note
Go to the end to download the full example code.
CIFAR-10 Cached Baseline
Benchmark code for CIFAR-10 that reuses cached logits and pooled features
when fitting detectors that expose the standard representation APIs. At
evaluation time, cached logits are used for detectors that score logits
directly, while feature detectors that operate directly on pooled features
use the cached feature representation. All other detectors still go through
their regular predict() interface, so their full model-time behavior is preserved.
15 from copy import deepcopy
16 import gc
17 import pandas as pd # additional dependency, used here for convenience
18 import torch
19 from torch import nn
20 from torch.utils.data import DataLoader
21 from torchvision.datasets import CIFAR10, CIFAR100, FashionMNIST, MNIST
22 from tqdm.auto import tqdm # additional dependency, used here for convenience
23
24 from pytorch_ood.dataset.img import (
25 LSUNCrop,
26 LSUNResize,
27 Places365,
28 Textures,
29 TinyImageNetCrop,
30 TinyImageNetResize,
31 )
32 from pytorch_ood.api import FeaturesDetector, LogitsDetector
33 from pytorch_ood.detector import (
34 ASH,
35 DICE,
36 EnergyBased,
37 Entropy,
38 GEN,
39 GMM,
40 GradNorm,
41 GradNormKL,
42 Gram,
43 KLMatching,
44 KNN,
45 Mahalanobis,
46 MaxLogit,
47 MaxSoftmax,
48 MultiMahalanobis,
49 NACUE,
50 ODIN,
51 RMD,
52 RankFeat,
53 SHE,
54 ViM,
55 fDBD,
56 )
57 from pytorch_ood.model import WideResNet
58 from pytorch_ood.utils import (
59 OODMetrics,
60 TensorBuffer,
61 ToUnknown,
62 extract_features,
63 fix_random_seed,
64 )
65
# Use the first CUDA device when one is available and fall back to the CPU
# otherwise, so the example does not fail on machines without a GPU. The
# cleanup step at the end of the script already guards on
# ``torch.cuda.is_available()``; this keeps the device choice consistent.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed via the pytorch_ood helper before any model or
# detector is created.
fix_random_seed(123)
69
70
def cache_representations(
    data_loader: DataLoader, model: nn.Module, feature_model: nn.Module, device: str
):
    """
    Run every batch of ``data_loader`` through ``model`` and ``feature_model``
    and collect the flattened outputs together with the batch labels.

    Unlike ``extract_features(...)``, this keeps OOD samples and labels:
    nothing is filtered here, every label yielded by the loader is stored.

    :param data_loader: loader yielding ``(inputs, labels)`` batches
    :param model: module whose output is cached under ``"logits"``
    :param feature_model: module whose output is cached under ``"features"``
    :param device: device the input batches are moved to before the forward passes
    :return: dict with the keys ``"logits"``, ``"features"`` and ``"labels"``
    """
    print(f"Extracting cached logits/features from {len(data_loader.dataset)} samples")
    store = TensorBuffer()

    # Only forward passes are needed, so gradient tracking is switched off.
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            batch_size = inputs.shape[0]

            # Flatten each output to (batch, -1) before buffering it.
            logits = model(inputs)
            store.append("logits", logits.view(batch_size, -1))

            pooled = feature_model(inputs)
            store.append("features", pooled.view(batch_size, -1))

            store.append("label", targets)

    # The buffer key for labels is singular ("label"); the returned key is plural.
    key_map = (("logits", "logits"), ("features", "features"), ("labels", "label"))
    return {out_key: store.get(buffer_key) for out_key, buffer_key in key_map}
93
94
def fit_detector(
    detector, train_loader: DataLoader, train_cache: dict, train_labels: torch.Tensor, device: str
):
    """
    Fit ``detector`` if it reports ``requires_fit``, preferring cached
    representations over a pass through the training loader.

    ``LogitsDetector`` instances are fitted with ``fit_logits`` on the cached
    logits, ``FeaturesDetector`` instances with ``fit_features`` on the cached
    features. Any other detector is moved to ``device`` and fitted with
    ``fit(train_loader)``.

    :param detector: detector to fit; nothing happens if ``requires_fit`` is false
    :param train_loader: loader used only for the ``fit(...)`` fallback
    :param train_cache: dict with the keys ``"logits"`` and ``"features"``
    :param train_labels: labels passed along with the cached representations
    :param device: device the detector is moved to before the ``fit(...)`` fallback
    """
    # Guard clause: detectors that need no fitting are left untouched.
    if not detector.requires_fit:
        return

    if isinstance(detector, LogitsDetector):
        detector.fit_logits(train_cache["logits"], train_labels)
    elif isinstance(detector, FeaturesDetector):
        detector.fit_features(train_cache["features"], train_labels)
    else:
        # No cached representation applies, so run the detector's own fit.
        detector.to(device)
        detector.fit(train_loader)
114
def evaluate_detector(detector, data_loader: DataLoader, eval_cache: dict, device: str) -> dict:
    """
    Score one detector on one evaluation set and return the computed metrics.

    Cached logits are used for ``LogitsDetector`` instances and cached pooled
    features for ``FeaturesDetector`` instances, except for a ``Mahalanobis``
    detector with ``eps > 0``, which is evaluated on the raw inputs. Every
    other detector is also evaluated batch by batch by calling it on the
    inputs from ``data_loader``.

    :param detector: detector to evaluate; it is moved to ``device`` first
    :param data_loader: loader yielding ``(inputs, labels)`` for the non-cached path
    :param eval_cache: dict with the keys ``"logits"``, ``"features"`` and ``"labels"``
    :param device: device used for the detector and for the non-cached batches
    :return: result of ``OODMetrics.compute()``
    """
    metrics = OODMetrics()
    detector.to(device)

    if isinstance(detector, LogitsDetector):
        scores = detector.predict_logits(eval_cache["logits"])
    elif isinstance(detector, FeaturesDetector) and not (
        isinstance(detector, Mahalanobis) and detector.eps > 0
    ):
        scores = detector.predict_features(eval_cache["features"])
    else:
        # Non-cached path: call the detector on every input batch.
        for inputs, targets in tqdm(data_loader, leave=False):
            metrics.update(detector(inputs.to(device)), targets.to(device))
        return metrics.compute()

    # Cached path: move the labels to wherever the scores ended up.
    metrics.update(scores, eval_cache["labels"].to(scores.device))
    return metrics.compute()
Setup preprocessing
# Normalisation statistics and input transform for the "cifar10-pt" weights.
# ``norm_std`` is handed to the detectors that take a ``norm_std`` argument,
# ``trans`` to every dataset.
norm_std = WideResNet.norm_std_for("cifar10-pt")
trans = WideResNet.transform_for("cifar10-pt")
Setup datasets
# In-distribution data: the CIFAR-10 test split.
dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)

# Dataset classes that serve as OOD data.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]

# One loader per OOD dataset, keyed by the dataset class name. Each loader
# iterates over the CIFAR-10 test split concatenated with the OOD dataset,
# whose targets are passed through ``ToUnknown()``.
datasets = {
    ood_cls.__name__: DataLoader(
        dataset_in_test
        + ood_cls(root="data", transform=trans, target_transform=ToUnknown(), download=True),
        batch_size=128,
        num_workers=12,
    )
    for ood_cls in ood_datasets
}
Stage 1: Create model
print("STAGE 1: Creating a Model")
# Build the WideResNet with the "cifar10-pt" pretrained weights, then switch
# it to evaluation mode and move it to the target device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)
Stage 2: Create detectors
print("STAGE 2: Creating OOD Detectors")


def _copy_with_trainable_fc(net: nn.Module) -> nn.Module:
    """Deep-copy ``net`` and leave gradients enabled only on its ``fc`` layer."""
    clone = deepcopy(net)
    clone.requires_grad_(False)
    clone.fc.requires_grad_(True)
    return clone


def _is_fc_parameter(name: str) -> bool:
    """Parameter filter that selects names starting with ``"fc"``."""
    return name.startswith("fc")


def _intermediate_layers() -> list:
    """Return a fresh list of the layers handed to MultiMahalanobis and Gram."""
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


# Detectors are created in a fixed order; the dict order is also the order
# in which they are later fitted and evaluated.
detectors = {
    "KNN": KNN(model.features),
    "GMM": GMM(model.features),
    "fDBD": fDBD(encoder=model.features, head=model.fc),
    "ASH": ASH(backbone=model.features_before_pool, head=model.forward_from_before_pool),
    "RankFeat": RankFeat(
        backbone=model.features_before_pool, head=model.forward_from_before_pool
    ),
}

# GradNorm and GradNormKL each get their own copy of the model in which only
# the ``fc`` layer requires gradients.
model_gn = _copy_with_trainable_fc(model)
detectors["GradNorm"] = GradNorm(model_gn, param_filter=_is_fc_parameter)

model_gnkl = _copy_with_trainable_fc(model)
detectors["GradNormKL"] = GradNormKL(model_gnkl, param_filter=_is_fc_parameter)

detectors.update(
    {
        "Entropy": Entropy(model),
        "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
        # Mahalanobis with eps > 0 is the variant that evaluate_detector runs
        # on raw inputs instead of cached features.
        "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
        "Mahalanobis": Mahalanobis(model.features),
        "KLMatching": KLMatching(model),
        "SHE": SHE(model.features, model.fc),
        "MSP": MaxSoftmax(model),
        "EnergyBased": EnergyBased(model),
        "GEN": GEN(model),
        "MaxLogit": MaxLogit(model),
        "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
        "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
        "RMD": RMD(model.features),
        "MultiMahalanobis": MultiMahalanobis(_intermediate_layers()),
        "Gram": Gram(
            num_classes=10,
            head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
            feature_layers=_intermediate_layers(),
        ),
        "NAC-UE": NACUE(
            model=model,
            layers=[model.block2, model.block3, model.bn1],
            m_bins=[200, 200, 200],
            alpha=[150.0, 200.0, 250.0],
            o_star=[25, 50, 100],
            device=device,
        ),
    }
)
Stage 3: Fit detectors
print(f"STAGE 3: Fitting {len(detectors)} detectors")
train_set = CIFAR10(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=128, num_workers=12)

# Two passes over the same training loader: one through the full model for
# the logits, one through ``model.features`` for the pooled features.
print("Extracting training logits")
train_logits, train_labels_logits = extract_features(loader_in_train, model, device)
print("Extracting training pooled features")
train_features, train_labels_features = extract_features(loader_in_train, model.features, device)

# Both passes must return the same labels in the same order; otherwise a
# single label tensor could not be shared by the cached logits and features.
assert torch.equal(train_labels_logits, train_labels_features)
train_labels = train_labels_logits
train_cache = dict(logits=train_logits, features=train_features)

for fit_name, fit_target in detectors.items():
    print(f"--> Fitting {fit_name}")
    fit_detector(fit_target, loader_in_train, train_cache, train_labels, device)
Stage 4: Evaluate detectors
print(f"STAGE 4: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []

for dataset_name, loader in datasets.items():
    # Compute logits/features once per dataset and share them across detectors.
    print(f"> Caching representations for {dataset_name}")
    eval_cache = cache_representations(loader, model, model.features, device)

    for detector_name, detector in detectors.items():
        print(f"--> {detector_name}")
        scores = evaluate_detector(detector, loader, eval_cache, device)
        results.append({"Detector": detector_name, "Dataset": dataset_name, **scores})

    # Drop this dataset's cache before the next one is built.
    del eval_cache
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Average every metric per detector over all datasets and report it in percent,
# sorted by AUROC.
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
per_detector = df.groupby("Detector")[metric_columns]
mean_scores = per_detector.mean() * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))
296 print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))