Note
Go to the end to download the full example code.
CIFAR 10
Example benchmark code for CIFAR10
| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| GEN | 93.38 | 29.75 | 87.80 | 94.62 | 29.50 |
| EnergyBased | 93.11 | 35.45 | 87.09 | 94.46 | 31.14 |
| ASH | 93.06 | 35.70 | 77.62 | 94.43 | 31.31 |
| MaxLogit | 93.05 | 35.84 | 87.01 | 94.40 | 31.31 |
| MultiMahalanobis | 92.89 | 45.35 | 85.51 | 96.09 | 24.95 |
| DICE | 92.80 | 35.83 | 86.68 | 94.20 | 32.35 |
| KNN | 92.67 | 36.61 | 87.07 | 94.21 | 29.49 |
| RMD | 92.61 | 31.24 | 87.21 | 93.81 | 27.96 |
| Mahalanobis+ODIN | 92.60 | 42.76 | 86.81 | 95.08 | 27.11 |
| ViM | 92.31 | 40.25 | 85.77 | 94.93 | 29.48 |
| ODIN | 92.14 | 47.06 | 84.98 | 94.46 | 34.32 |
| Mahalanobis | 91.82 | 42.93 | 86.21 | 93.85 | 28.60 |
| fDBD | 91.82 | 35.31 | 83.54 | 93.98 | 35.72 |
| Entropy | 92.03 | 35.90 | 86.70 | 93.47 | 29.75 |
| MSP | 91.41 | 37.07 | 86.36 | 92.42 | 29.93 |
| SHE | 90.08 | 39.69 | 69.17 | 92.92 | 38.48 |
| GMM | 89.99 | 42.95 | 85.59 | 90.26 | 30.42 |
| NAC-UE | 88.74 | 39.89 | 81.40 | 90.36 | 46.12 |
| KLMatching | 88.48 | 39.83 | 72.29 | 91.33 | 57.84 |
| GradNormKL | 80.97 | 49.97 | 68.70 | 89.49 | 79.72 |
| Gram | 69.37 | 46.01 | 58.02 | 77.49 | 75.03 |
| RankFeat | 55.43 | 49.92 | 45.31 | 63.80 | 86.35 |
| GradNorm | 50.00 | 60.78 | 18.37 | 81.63 | 100.00 |
62 import pandas as pd # additional dependency, used here for convenience
63 from torch import nn
64 from torch.utils.data import DataLoader
65 from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
66 from copy import deepcopy
67 from tqdm.auto import tqdm # additional dependency, used here for convenience
68 import torch
69
70 from pytorch_ood.dataset.img import (
71 LSUNCrop,
72 LSUNResize,
73 Textures,
74 TinyImageNetCrop,
75 TinyImageNetResize,
76 Places365,
77 )
78 from pytorch_ood.detector import (
79 ODIN,
80 EnergyBased,
81 Entropy,
82 GEN,
83 KLMatching,
84 Mahalanobis,
85 MahalanobisODIN,
86 MaxLogit,
87 MaxSoftmax,
88 ViM,
89 RMD,
90 DICE,
91 SHE,
92 Gram,
93 GMM,
94 MultiMahalanobis,
95 NACUE,
96 GradNorm,
97 GradNormKL,
98 ASH,
99 KNN,
100 RankFeat,
101 fDBD,
102 )
103 from pytorch_ood.model import WideResNet
104 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
105
# Prefer the first CUDA device, but fall back to the CPU so that the example
# still runs (much more slowly) on machines without a usable GPU instead of
# failing when the model is moved to the device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed for this run.
fix_random_seed(123)
Setup preprocessing
# Both lookups use the "cifar10-pt" key, the same identifier that is passed
# as ``pretrained=`` when the WideResNet is created further down.
norm_std = WideResNet.norm_std_for("cifar10-pt")  # forwarded to ODIN and Mahalanobis+ODIN
trans = WideResNet.transform_for("cifar10-pt")  # input transform used for every dataset
Setup datasets
# In-distribution data: the CIFAR10 test split.
dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)

# Dataset classes that supply the OOD samples.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]

# One loader per OOD dataset, keyed by the dataset class name. Each loader
# serves the CIFAR10 test split concatenated with that OOD dataset; the
# targets of the OOD samples are passed through ``ToUnknown()``.
datasets = {
    ood_cls.__name__: DataLoader(
        dataset_in_test
        + ood_cls(root="data", transform=trans, target_transform=ToUnknown(), download=True),
        batch_size=128,
        num_workers=12,
    )
    for ood_cls in ood_datasets
}
Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper
print("STAGE 1: Creating a Model")
# Build the 10-class WideResNet with the "cifar10-pt" pre-trained weights,
# then switch it to evaluation mode and move it to the target device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)
Stage 2: Create OOD detector
print("STAGE 2: Creating OOD Detectors")


def _is_fc_param(name):
    """Return True for parameter names that start with ``fc``."""
    return name.startswith("fc")


def _fc_only_copy(net):
    """Return a deep copy of *net* in which only ``fc`` requires gradients.

    Working on a copy keeps the deactivated gradients from influencing the
    detectors that share the original model.
    """
    clone = deepcopy(net)
    clone.requires_grad_(False)
    clone.fc.requires_grad_(True)
    return clone


def _layer_stack(net):
    """Return a fresh list of the layers handed to the multi-layer detectors."""
    return [
        net.conv1,
        net.block1,
        net.block2,
        net.block3,
        nn.Sequential(net.bn1, net.relu),
    ]


# Detectors are inserted in a fixed order; fitting and evaluation iterate
# over this dictionary in that same order.
detectors = {
    "KNN": KNN(model.features),
    "GMM": GMM(model.features),
    "fDBD": fDBD(encoder=model.features, head=model.fc),
    "ASH": ASH(backbone=model.feature_maps, head=model.forward_feature_maps),
    "RankFeat": RankFeat(backbone=model.feature_maps, head=model.forward_feature_maps),
}

# The gradient-norm detectors each get their own copy of the model
model_gn = _fc_only_copy(model)
detectors["GradNorm"] = GradNorm(model_gn, param_filter=_is_fc_param)

model_gnkl = _fc_only_copy(model)
detectors["GradNormKL"] = GradNormKL(model_gnkl, param_filter=_is_fc_param)

detectors.update(
    {
        "Entropy": Entropy(model),
        "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
        "Mahalanobis+ODIN": MahalanobisODIN(model.features, norm_std=norm_std, eps=0.002),
        "Mahalanobis": Mahalanobis(model.features),
        "KLMatching": KLMatching(model),
        "SHE": SHE(model.features, model.fc),
        "MSP": MaxSoftmax(model),
        "EnergyBased": EnergyBased(model),
        "GEN": GEN(model),
        "MaxLogit": MaxLogit(model),
        "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
        "DICE": DICE(encoder=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
        "RMD": RMD(model.features),
        "MultiMahalanobis": MultiMahalanobis(_layer_stack(model)),
        "Gram": Gram(
            num_classes=10,
            head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
            feature_layers=_layer_stack(model),
        ),
    }
)

# hyperparameters determined on Textures dataset
detectors["NAC-UE"] = NACUE(
    model=model,
    layers=[model.block2, model.block3, model.bn1],
    m_bins=[200, 200, 200],
    alpha=[150.0, 200.0, 250.0],
    o_star=[25, 50, 100],
    device=device,
)

# Every detector is fitted on the CIFAR10 training split; some of them
# require this step, others do not.
print(f"> Fitting {len(detectors)} detectors")
train_set = CIFAR10(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=128, num_workers=12)
for label, det in detectors.items():
    print(f"--> Fitting {label}")
    det.to(device)
    det.fit(loader_in_train)
Stage 3: Evaluate Detectors
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []

# The whole evaluation runs inside ``torch.no_grad()``.
with torch.no_grad():
    for det_label, det in detectors.items():
        print(f"> Evaluating {det_label}")
        for set_label, batches in datasets.items():
            print(f"--> {set_label}")
            metrics = OODMetrics()
            for inputs, targets in tqdm(batches, desc=set_label):
                scores = det(inputs.to(device))
                metrics.update(scores, targets.to(device))

            # One result row per (detector, dataset) pair; the computed
            # metrics are merged into the same dictionary.
            results.append({"Detector": det_label, "Dataset": set_label, **metrics.compute()})
242
# Average every metric per detector over all datasets and scale to percent.
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean() * 100

# Emit the table as CSV, ordered by ascending AUROC, with two decimals.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))