Note
Go to the end to download the full example code.
CIFAR 10
Example benchmark code for CIFAR10
| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| GradNorm | 50.00 | 60.78 | 18.37 | 81.63 | 100.00 |
| Gram | 69.37 | 46.01 | 58.02 | 77.49 | 75.03 |
| KLMatching | 88.48 | 39.83 | 72.29 | 91.33 | 57.84 |
| NAC-UE | 88.74 | 39.89 | 81.40 | 90.36 | 46.12 |
| SHE | 90.08 | 39.69 | 69.17 | 92.92 | 38.48 |
| MSP | 91.41 | 37.07 | 86.36 | 92.42 | 29.93 |
| Entropy | 92.03 | 35.90 | 86.70 | 93.47 | 29.75 |
| Mahalanobis | 92.14 | 42.76 | 86.39 | 94.36 | 28.25 |
| ODIN | 92.14 | 47.06 | 84.98 | 94.46 | 34.43 |
| ViM | 92.32 | 40.22 | 85.76 | 94.93 | 29.49 |
| Mahalanobis+ODIN | 92.60 | 42.76 | 86.81 | 95.08 | 27.11 |
| KNN | 92.67 | 36.61 | 87.07 | 94.21 | 29.49 |
| DICE | 92.80 | 35.83 | 86.68 | 94.20 | 32.35 |
| MultiMahalanobis | 92.89 | 45.35 | 85.51 | 96.09 | 24.95 |
| MaxLogit | 93.05 | 35.84 | 87.01 | 94.40 | 31.31 |
| ASH | 93.06 | 35.70 | 77.62 | 94.43 | 31.31 |
| EnergyBased | 93.11 | 35.45 | 87.09 | 94.46 | 31.14 |
| RMD | 93.46 | 32.09 | 87.73 | 95.08 | 26.99 |
52 import pandas as pd # additional dependency, used here for convenience
53 from torch import nn
54 from torch.utils.data import DataLoader
55 from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
56 from copy import deepcopy
57 from tqdm.auto import tqdm # additional dependency, used here for convenience
58 import torch
59
60 from pytorch_ood.dataset.img import (
61 LSUNCrop,
62 LSUNResize,
63 Textures,
64 TinyImageNetCrop,
65 TinyImageNetResize,
66 Places365,
67 )
68 from pytorch_ood.detector import (
69 ODIN,
70 EnergyBased,
71 Entropy,
72 GEN,
73 KLMatching,
74 Mahalanobis,
75 MaxLogit,
76 MaxSoftmax,
77 ViM,
78 RMD,
79 DICE,
80 SHE,
81 Gram,
82 GMM,
83 MultiMahalanobis,
84 NACUE,
85 GradNorm,
86 GradNormKL,
87 ASH,
88 KNN,
89 RankFeat,
90 fDBD,
91 )
92 from pytorch_ood.model import WideResNet
93 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
94
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the example still runs (slowly) on machines without a GPU instead of
# failing when the model is moved to the device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed via the pytorch_ood helper.
fix_random_seed(123)
Setup preprocessing
# Both preprocessing pieces are looked up on WideResNet under the "cifar10-pt"
# key, the same key that is passed as `pretrained=` when the model is created.
norm_std = WideResNet.norm_std_for("cifar10-pt")
trans = WideResNet.transform_for("cifar10-pt")
Setup datasets
# In-distribution data: the CIFAR10 test split.
dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)

# create all OOD datasets
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]

# One loader per OOD dataset, keyed by the dataset class name. Every loader
# serves the CIFAR10 test split combined (via `+`) with one OOD dataset whose
# targets are passed through ToUnknown().
datasets = {}
for ood_dataset in ood_datasets:
    dataset_out_test = ood_dataset(
        root="data",
        transform=trans,
        target_transform=ToUnknown(),
        download=True,
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test,
        batch_size=128,
        num_workers=12,
    )
    datasets[ood_dataset.__name__] = test_loader
Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper
print("STAGE 1: Creating a Model")
# Instantiate the 10-class WideResNet with the "cifar10-pt" weights, then
# switch it to evaluation mode and move it to the selected device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)
Stage 2: Create OOD detectors
print("STAGE 2: Creating OOD Detectors")


def _is_fc_param(name):
    """Return True for parameter names that start with "fc"."""
    return name.startswith("fc")


def _copy_with_only_fc_trainable(net):
    """Return a deep copy of *net* with gradients enabled only for ``net.fc``."""
    clone = deepcopy(net)
    clone.requires_grad_(False)
    clone.fc.requires_grad_(True)
    return clone


# Detectors are registered by name; the insertion order below is also the
# order in which they are later fitted and evaluated.
detectors = {}

# detectors that work on the feature encoder
detectors["KNN"] = KNN(model.features)
detectors["GMM"] = GMM(model.features)
detectors["fDBD"] = fDBD(encoder=model.features, head=model.fc)

# detectors that split the network at the layer before pooling
detectors["ASH"] = ASH(
    backbone=model.features_before_pool,
    head=model.forward_from_before_pool,
)
detectors["RankFeat"] = RankFeat(
    backbone=model.features_before_pool,
    head=model.forward_from_before_pool,
)

# we make a copy of the model just so deactivating gradients does not influence other detectors
model_gn = _copy_with_only_fc_trainable(model)
detectors["GradNorm"] = GradNorm(model_gn, param_filter=_is_fc_param)

model_gnkl = _copy_with_only_fc_trainable(model)
detectors["GradNormKL"] = GradNormKL(model_gnkl, param_filter=_is_fc_param)

detectors["Entropy"] = Entropy(model)
detectors["ViM"] = ViM(
    model.features,
    d=64,
    w=model.fc.weight,
    b=model.fc.bias,
)
detectors["Mahalanobis+ODIN"] = Mahalanobis(
    model.features,
    norm_std=norm_std,
    eps=0.002,
)
detectors["Mahalanobis"] = Mahalanobis(model.features)

detectors["KLMatching"] = KLMatching(model)
detectors["SHE"] = SHE(model.features, model.fc)
detectors["MSP"] = MaxSoftmax(model)
detectors["EnergyBased"] = EnergyBased(model)
detectors["GEN"] = GEN(model)
detectors["MaxLogit"] = MaxLogit(model)
detectors["ODIN"] = ODIN(model, norm_std=norm_std, eps=0.002)
detectors["DICE"] = DICE(
    model=model.features,
    w=model.fc.weight,
    b=model.fc.bias,
    p=0.65,
)
detectors["RMD"] = RMD(model.features)

# detectors that consume a list of consecutive network stages
detectors["MultiMahalanobis"] = MultiMahalanobis(
    [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]
)
detectors["Gram"] = Gram(
    num_classes=10,
    head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
    feature_layers=[
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ],
)

# hyperparameters determined on Textures dataset
detectors["NAC-UE"] = NACUE(
    model=model,
    layers=[model.block2, model.block3, model.bn1],
    m_bins=[200, 200, 200],
    alpha=[150.0, 200.0, 250.0],
    o_star=[25, 50, 100],
    device=device,
)
204
# fit detectors to training data (some require this, some do not)
print(f"> Fitting {len(detectors)} detectors")

# All detectors are fitted on the same loader over the CIFAR10 training split.
loader_in_train = DataLoader(
    CIFAR10(root="data", train=True, transform=trans),
    batch_size=128,
    num_workers=12,
)

for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    # move the detector to the target device before fitting it
    detector.to(device)
    detector.fit(loader_in_train)
Stage 3: Evaluate Detectors
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")

# One row (dict) per (detector, dataset) pair.
results = []

# NOTE(review): gradient-based detectors (e.g. GradNorm, ODIN) are also called
# inside this no_grad context; presumably they re-enable gradients internally
# -- confirm.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")

            # fresh metric accumulator for every (detector, dataset) pair
            metrics = OODMetrics()
            for inputs, labels in tqdm(loader, desc=dataset_name):
                scores = detector(inputs.to(device))
                metrics.update(scores, labels.to(device))

            # identifying columns first, then every metric from compute()
            results.append(
                {
                    "Detector": detector_name,
                    "Dataset": dataset_name,
                    **metrics.compute(),
                }
            )
233
# calculate mean scores over all datasets, use percent
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean() * 100

# print as CSV with two decimals, sorted by ascending AUROC
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))