Note
Go to the end to download the full example code.
CIFAR 10
Example benchmark code for CIFAR10
| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| GradNorm | 50.00 | 60.78 | 18.37 | 81.63 | 100.00 |
| Gram | 69.37 | 46.01 | 58.02 | 77.49 | 75.03 |
| KLMatching | 88.48 | 39.83 | 72.29 | 91.33 | 57.84 |
| NAC-UE | 88.74 | 39.89 | 81.40 | 90.36 | 46.12 |
| SHE | 90.08 | 39.69 | 69.17 | 92.92 | 38.48 |
| MSP | 91.41 | 37.07 | 86.36 | 92.42 | 29.93 |
| Entropy | 92.03 | 35.90 | 86.70 | 93.47 | 29.75 |
| Mahalanobis | 92.14 | 42.76 | 86.39 | 94.36 | 28.25 |
| ODIN | 92.14 | 47.06 | 84.98 | 94.46 | 34.43 |
| ViM | 92.32 | 40.22 | 85.76 | 94.93 | 29.49 |
| Mahalanobis+ODIN | 92.60 | 42.76 | 86.81 | 95.08 | 27.11 |
| KNN | 92.67 | 36.61 | 87.07 | 94.21 | 29.49 |
| DICE | 92.80 | 35.83 | 86.68 | 94.20 | 32.35 |
| MultiMahalanobis | 92.89 | 45.35 | 85.51 | 96.09 | 24.95 |
| MaxLogit | 93.05 | 35.84 | 87.01 | 94.40 | 31.31 |
| ASH | 93.06 | 35.70 | 77.62 | 94.43 | 31.31 |
| EnergyBased | 93.11 | 35.45 | 87.09 | 94.46 | 31.14 |
| RMD | 93.46 | 32.09 | 87.73 | 95.08 | 26.99 |
52 import pandas as pd # additional dependency, used here for convenience
53 from torch import nn
54 from torch.utils.data import DataLoader
55 from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
56 from copy import deepcopy
57 from tqdm.auto import tqdm # additional dependency, used here for convenience
58 import torch
59
60 from pytorch_ood.dataset.img import (
61 LSUNCrop,
62 LSUNResize,
63 Textures,
64 TinyImageNetCrop,
65 TinyImageNetResize,
66 Places365,
67 )
68 from pytorch_ood.detector import (
69 ODIN,
70 EnergyBased,
71 Entropy,
72 KLMatching,
73 Mahalanobis,
74 MaxLogit,
75 MaxSoftmax,
76 ViM,
77 RMD,
78 DICE,
79 SHE,
80 Gram,
81 MultiMahalanobis,
82 NACUE,
83 GradNorm,
84 ASH,
85 KNN,
86 )
87 from pytorch_ood.model import WideResNet
88 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
89
# Device on which the model and all detectors run. Use the first CUDA GPU when
# one is available and fall back to the CPU otherwise, so that the example
# does not fail on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Seed the random number generators with a fixed value via the pytorch-ood helper.
fix_random_seed(123)
Setup preprocessing
# Both values are looked up on the WideResNet class for the "cifar10-pt" tag:
# the normalization std (passed to the ODIN-style detectors further down) ...
norm_std = WideResNet.norm_std_for("cifar10-pt")
# ... and the input transform that is applied to every dataset in this example.
trans = WideResNet.transform_for("cifar10-pt")
Setup datasets
# In-distribution test data: the CIFAR10 test split.
dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)

# Dataset classes that are used as sources of OOD samples.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]


def _build_test_loader(ood_cls):
    """Return a loader over the CIFAR10 test set concatenated with ``ood_cls``.

    The OOD dataset is instantiated with the shared input transform and with
    ``ToUnknown()`` as target transform (presumably marks every OOD label as
    "unknown" -- verify against pytorch_ood.utils).
    """
    ood_data = ood_cls(
        root="data",
        transform=trans,
        target_transform=ToUnknown(),
        download=True,
    )
    return DataLoader(dataset_in_test + ood_data, batch_size=256, num_workers=12)


# One test loader per OOD dataset, keyed by the name of the dataset class.
datasets = {ood_cls.__name__: _build_test_loader(ood_cls) for ood_cls in ood_datasets}
Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper
print("STAGE 1: Creating a Model")
# 10-class WideResNet initialised from the "cifar10-pt" pre-trained weights ...
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
# ... switched to evaluation mode and moved to the target device.
model = model.eval()
model = model.to(device)
Stage 2: Create OOD detectors
print("STAGE 2: Creating OOD Detectors")


def _intermediate_layers():
    """Return a fresh list of the model stages used by the multi-layer detectors.

    A new list (and a new ``nn.Sequential`` wrapper around ``bn1`` + ``relu``)
    is created on every call, so each detector receives its own instance.
    """
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


def _is_fc_param(name):
    """Select parameters whose name starts with ``"fc"``."""
    return name.startswith("fc")


# Insertion order of this dict determines the order of fitting and evaluation.
detectors = {
    "KNN": KNN(model.features),
    "ASH": ASH(
        backbone=model.features_before_pool,
        head=model.forward_from_before_pool,
    ),
}

# GradNorm works on its own copy of the model, so that switching off gradients
# for everything except the ``fc`` layer does not affect the other detectors.
model_gn = deepcopy(model)
model_gn.requires_grad_(False)
model_gn.fc.requires_grad_(True)
detectors["GradNorm"] = GradNorm(model_gn, param_filter=_is_fc_param)

detectors.update(
    {
        "Entropy": Entropy(model),
        "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
        "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
        "Mahalanobis": Mahalanobis(model.features),
        "KLMatching": KLMatching(model),
        "SHE": SHE(model.features, model.fc),
        "MSP": MaxSoftmax(model),
        "EnergyBased": EnergyBased(model),
        "MaxLogit": MaxLogit(model),
        "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
        "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
        "RMD": RMD(model.features),
        "MultiMahalanobis": MultiMahalanobis(_intermediate_layers()),
        "Gram": Gram(
            num_classes=10,
            head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
            feature_layers=_intermediate_layers(),
        ),
        # hyperparameters determined on Textures dataset
        "NAC-UE": NACUE(
            model=model,
            layers=[model.block2, model.block3, model.bn1],
            m_bins=[200, 200, 200],
            alpha=[150.0, 200.0, 250.0],
            o_star=[25, 50, 100],
            device=device,
        ),
    }
)

# Every detector is fitted on the CIFAR10 training split
# (some require this, some do not).
print(f"> Fitting {len(detectors)} detectors")
dataset_in_train = CIFAR10(root="data", train=True, transform=trans)
loader_in_train = DataLoader(dataset_in_train, batch_size=128, num_workers=12)
for det_name, det in detectors.items():
    print(f"--> Fitting {det_name}")
    det.fit(loader_in_train, device=device)
Stage 3: Evaluate Detectors
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")

# One row (dict) per (detector, dataset) pair.
results = []

# NOTE(review): the whole evaluation runs under torch.no_grad(). This assumes
# that gradient-based detectors (e.g. GradNorm, ODIN) re-enable autograd
# internally -- TODO confirm; the GradNorm row of the reported table (AUROC
# 50.00, FPR95TPR 100.00) looks like constant scores.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for inputs, labels in tqdm(loader, desc=dataset_name):
                scores = detector(inputs.to(device))
                metrics.update(scores, labels.to(device))

            # identification columns first, followed by the computed metrics
            results.append(
                {
                    "Detector": detector_name,
                    "Dataset": dataset_name,
                    **metrics.compute(),
                }
            )

# Average each metric over all datasets per detector and convert to percent.
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean() * 100
# Print as CSV, sorted by ascending AUROC, with two decimal places.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))