CIFAR-10

Example benchmark code for CIFAR-10: a range of OOD detectors is evaluated on top of a WideResNet pre-trained on CIFAR-10.

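Results (in %), sorted by AUROC:

| Detector         | AUROC | AUTC  | AUPR-IN | AUPR-OUT | FPR95TPR |
|------------------|-------|-------|---------|----------|----------|
| GradNorm         | 50.00 | 60.78 | 18.37   | 81.63    | 100.00   |
| Gram             | 69.37 | 46.01 | 58.02   | 77.49    | 75.03    |
| KLMatching       | 88.48 | 39.83 | 72.29   | 91.33    | 57.84    |
| NAC-UE           | 88.74 | 39.89 | 81.40   | 90.36    | 46.12    |
| SHE              | 90.08 | 39.69 | 69.17   | 92.92    | 38.48    |
| MSP              | 91.41 | 37.07 | 86.36   | 92.42    | 29.93    |
| Entropy          | 92.03 | 35.90 | 86.70   | 93.47    | 29.75    |
| Mahalanobis      | 92.14 | 42.76 | 86.39   | 94.36    | 28.25    |
| ODIN             | 92.14 | 47.06 | 84.98   | 94.46    | 34.43    |
| ViM              | 92.32 | 40.22 | 85.76   | 94.93    | 29.49    |
| Mahalanobis+ODIN | 92.60 | 42.76 | 86.81   | 95.08    | 27.11    |
| KNN              | 92.67 | 36.61 | 87.07   | 94.21    | 29.49    |
| DICE             | 92.80 | 35.83 | 86.68   | 94.20    | 32.35    |
| MultiMahalanobis | 92.89 | 45.35 | 85.51   | 96.09    | 24.95    |
| MaxLogit         | 93.05 | 35.84 | 87.01   | 94.40    | 31.31    |
| ASH              | 93.06 | 35.70 | 77.62   | 94.43    | 31.31    |
| EnergyBased      | 93.11 | 35.45 | 87.09   | 94.46    | 31.14    |
| RMD              | 93.46 | 32.09 | 87.73   | 95.08    | 26.99    |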

from copy import deepcopy

import pandas as pd  # additional dependency, used here for convenience
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
from tqdm.auto import tqdm  # additional dependency, used here for convenience

from pytorch_ood.dataset.img import (
    LSUNCrop,
    LSUNResize,
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    Places365,
)
from pytorch_ood.detector import (
    ODIN,
    EnergyBased,
    Entropy,
    GEN,
    KLMatching,
    Mahalanobis,
    MaxLogit,
    MaxSoftmax,
    ViM,
    RMD,
    DICE,
    SHE,
    Gram,
    GMM,
    MultiMahalanobis,
    NACUE,
    GradNorm,
    GradNormKL,
    ASH,
    KNN,
    RankFeat,
    fDBD,
)
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed

# Prefer the first GPU, but fall back to the CPU so the example still runs
# (slowly) on machines without CUDA instead of failing on the first `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# fix random seeds (pytorch-ood helper) so that runs are repeatable
fix_random_seed(123)

Setup preprocessing

# Input transform and the normalization std that belong to the "cifar10-pt"
# pre-trained weights; norm_std is later handed to the detectors that take a
# `norm_std` argument.
trans, norm_std = (
    WideResNet.transform_for("cifar10-pt"),
    WideResNet.norm_std_for("cifar10-pt"),
)

Setup datasets

# in-distribution test split; it is concatenated with every OOD dataset below
dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)

# OOD dataset classes to benchmark against, in evaluation order
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]
# NOTE(review): CIFAR100, MNIST and FashionMNIST are torchvision datasets built
# without `train=`, so torchvision's default split (train=True) is used for
# them — confirm whether their test splits were intended.


def _make_ood_test_loader(ood_cls):
    """Return a DataLoader over the CIFAR10 test set followed by ``ood_cls``.

    The OOD part is built with ``target_transform=ToUnknown()``, which (per its
    name) relabels its targets as unknown so OODMetrics can separate them from
    the in-distribution samples.
    """
    ood_part = ood_cls(root="data", transform=trans, target_transform=ToUnknown(), download=True)
    return DataLoader(dataset_in_test + ood_part, batch_size=128, num_workers=12)


# test loaders keyed by the OOD dataset's class name
datasets = {cls.__name__: _make_ood_test_loader(cls) for cls in ood_datasets}

Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper

print("STAGE 1: Creating a Model")
# WideResNet for the 10 CIFAR10 classes with the "cifar10-pt" weights,
# switched to evaluation mode and moved onto `device`
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)

Stage 2: Create OOD detectors

print("STAGE 2: Creating OOD Detectors")


def _grad_model(base):
    """Return a deep copy of ``base`` in which only ``fc`` parameters require gradients.

    The gradient flags are changed on the copy only, so the shared ``model``
    used by the other detectors keeps its original ``requires_grad`` settings.
    """
    clone = deepcopy(base)
    clone.requires_grad_(False)
    clone.fc.requires_grad_(True)
    return clone


def _stage_layers():
    """Return a new list of the WideResNet stages observed by multi-layer detectors.

    A fresh ``nn.Sequential(bn1, relu)`` wrapper is built on every call, so each
    detector receives its own wrapper object.
    """
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


def _only_fc(param_name):
    """Parameter filter for the gradient-norm detectors: keep the ``fc`` parameters."""
    return param_name.startswith("fc")


detectors = {}

# detectors built on model.features (fDBD additionally receives the fc head)
detectors.update(
    {
        "KNN": KNN(model.features),
        "GMM": GMM(model.features),
        "fDBD": fDBD(encoder=model.features, head=model.fc),
    }
)

# detectors that split the network at features_before_pool / forward_from_before_pool
detectors.update(
    {
        "ASH": ASH(backbone=model.features_before_pool, head=model.forward_from_before_pool),
        "RankFeat": RankFeat(
            backbone=model.features_before_pool, head=model.forward_from_before_pool
        ),
    }
)

# gradient-norm detectors: each one gets its own model copy (see _grad_model)
model_gn = _grad_model(model)
detectors["GradNorm"] = GradNorm(model_gn, param_filter=_only_fc)

model_gnkl = _grad_model(model)
detectors["GradNormKL"] = GradNormKL(model_gnkl, param_filter=_only_fc)

detectors.update(
    {
        "Entropy": Entropy(model),
        "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
        # the "+ODIN" variant additionally passes eps and norm_std
        "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
        "Mahalanobis": Mahalanobis(model.features),
        "KLMatching": KLMatching(model),
        "SHE": SHE(model.features, model.fc),
        "MSP": MaxSoftmax(model),
        "EnergyBased": EnergyBased(model),
        "GEN": GEN(model),
        "MaxLogit": MaxLogit(model),
        "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
        "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
        "RMD": RMD(model.features),
    }
)

# multi-layer detectors observing several network stages
detectors["MultiMahalanobis"] = MultiMahalanobis(_stage_layers())
detectors["Gram"] = Gram(
    num_classes=10,
    head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
    feature_layers=_stage_layers(),
)

# hyperparameters determined on Textures dataset
detectors["NAC-UE"] = NACUE(
    model=model,
    layers=[model.block2, model.block3, model.bn1],
    m_bins=[200, 200, 200],
    alpha=[150.0, 200.0, 250.0],
    o_star=[25, 50, 100],
    device=device,
)

# fit detectors to training data (some require this, some do not)
print(f"> Fitting {len(detectors)} detectors")
loader_in_train = DataLoader(
    CIFAR10(root="data", train=True, transform=trans), batch_size=128, num_workers=12
)
for det_name, det in detectors.items():
    print(f"--> Fitting {det_name}")
    det.to(device)
    det.fit(loader_in_train)

Stage 3: Evaluate Detectors

print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# one row per (detector, dataset) pair; metric columns come from OODMetrics.compute()
results = []

# NOTE(review): scoring runs under torch.no_grad(); detectors that need input or
# parameter gradients (ODIN, Mahalanobis+ODIN, GradNorm, GradNormKL) have to
# re-enable autograd internally — confirm.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            # fresh accumulator per pair so scores of different OOD sets are not mixed
            metrics = OODMetrics()
            for x, y in tqdm(loader, desc=dataset_name):
                metrics.update(detector(x.to(device)), y.to(device))

            # Record the result for *this* dataset. Placing these lines after the
            # dataset loop would keep only the last dataset of every detector.
            r = {"Detector": detector_name, "Dataset": dataset_name}
            r.update(metrics.compute())
            results.append(r)

# calculate mean scores over all datasets, use percent
df = pd.DataFrame(results)
mean_scores = (
    df.groupby("Detector")[["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]].mean() * 100
)
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery