CIFAR 10

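Example benchmark code for CIFAR10. The results below report, for each detector, the mean of every metric over all OOD datasets, in percent.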

| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| GEN | 93.38 | 29.75 | 87.80 | 94.62 | 29.50 |
| EnergyBased | 93.11 | 35.45 | 87.09 | 94.46 | 31.14 |
| ASH | 93.06 | 35.70 | 77.62 | 94.43 | 31.31 |
| MaxLogit | 93.05 | 35.84 | 87.01 | 94.40 | 31.31 |
| MultiMahalanobis | 92.89 | 45.35 | 85.51 | 96.09 | 24.95 |
| DICE | 92.80 | 35.83 | 86.68 | 94.20 | 32.35 |
| KNN | 92.67 | 36.61 | 87.07 | 94.21 | 29.49 |
| RMD | 92.61 | 31.24 | 87.21 | 93.81 | 27.96 |
| Mahalanobis+ODIN | 92.60 | 42.76 | 86.81 | 95.08 | 27.11 |
| ViM | 92.31 | 40.25 | 85.77 | 94.93 | 29.48 |
| ODIN | 92.14 | 47.06 | 84.98 | 94.46 | 34.32 |
| Mahalanobis | 91.82 | 42.93 | 86.21 | 93.85 | 28.60 |
| fDBD | 91.82 | 35.31 | 83.54 | 93.98 | 35.72 |
| Entropy | 92.03 | 35.90 | 86.70 | 93.47 | 29.75 |
| MSP | 91.41 | 37.07 | 86.36 | 92.42 | 29.93 |
| SHE | 90.08 | 39.69 | 69.17 | 92.92 | 38.48 |
| GMM | 89.99 | 42.95 | 85.59 | 90.26 | 30.42 |
| NAC-UE | 88.74 | 39.89 | 81.40 | 90.36 | 46.12 |
| KLMatching | 88.48 | 39.83 | 72.29 | 91.33 | 57.84 |
| GradNormKL | 80.97 | 49.97 | 68.70 | 89.49 | 79.72 |
| Gram | 69.37 | 46.01 | 58.02 | 77.49 | 75.03 |
| RankFeat | 55.43 | 49.92 | 45.31 | 63.80 | 86.35 |
| GradNorm | 50.00 | 60.78 | 18.37 | 81.63 | 100.00 |

 62 import pandas as pd  # additional dependency, used here for convenience
 63 from torch import nn
 64 from torch.utils.data import DataLoader
 65 from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
 66 from copy import deepcopy
 67 from tqdm.auto import tqdm  # additional dependency, used here for convenience
 68 import torch
 69
 70 from pytorch_ood.dataset.img import (
 71     LSUNCrop,
 72     LSUNResize,
 73     Textures,
 74     TinyImageNetCrop,
 75     TinyImageNetResize,
 76     Places365,
 77 )
 78 from pytorch_ood.detector import (
 79     ODIN,
 80     EnergyBased,
 81     Entropy,
 82     GEN,
 83     KLMatching,
 84     Mahalanobis,
 85     MahalanobisODIN,
 86     MaxLogit,
 87     MaxSoftmax,
 88     ViM,
 89     RMD,
 90     DICE,
 91     SHE,
 92     Gram,
 93     GMM,
 94     MultiMahalanobis,
 95     NACUE,
 96     GradNorm,
 97     GradNormKL,
 98     ASH,
 99     KNN,
100     RankFeat,
101     fDBD,
102 )
103 from pytorch_ood.model import WideResNet
104 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
105
# Prefer the first CUDA device, but fall back to the CPU so that the benchmark
# still runs (much more slowly) on machines without a usable GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed used for this benchmark run.
fix_random_seed(123)

Setup preprocessing

# The input transform and the normalization statistics are both looked up
# under the same pre-trained weight identifier.
weights_id = "cifar10-pt"
trans = WideResNet.transform_for(weights_id)
norm_std = WideResNet.norm_std_for(weights_id)

Setup datasets

# In-distribution data: the CIFAR10 test split, preprocessed with `trans`.
dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)

# Dataset classes that provide the OOD samples. Each one is instantiated with
# the same transform, and its targets are passed through ToUnknown().
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]

# One loader per OOD dataset, keyed by the dataset class name. Every loader
# iterates over the in-distribution test set followed by that OOD dataset.
datasets = {}
for ood_cls in ood_datasets:
    ood_set = ood_cls(
        root="data",
        transform=trans,
        target_transform=ToUnknown(),
        download=True,
    )
    combined = dataset_in_test + ood_set
    datasets[ood_cls.__name__] = DataLoader(combined, batch_size=128, num_workers=12)

Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper

print("STAGE 1: Creating a Model")

# Build the 10-class WideResNet with the "cifar10-pt" pre-trained weights,
# switch it to evaluation mode, and move it to the benchmark device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval()
model = model.to(device)

Stage 2: Create OOD detectors

print("STAGE 2: Creating OOD Detectors")


def _fc_only_copy(net):
    """Return a deep copy of *net* in which only ``fc`` requires gradients.

    Working on a copy keeps the gradient deactivation from influencing the
    detectors that share the original model.
    """
    clone = deepcopy(net)
    clone.requires_grad_(False)
    clone.fc.requires_grad_(True)
    return clone


def _is_fc_param(name):
    """Select the parameters whose name starts with ``fc``."""
    return name.startswith("fc")


def _layer_stack(net):
    """Build a fresh list of the layers used by the multi-layer detectors."""
    return [
        net.conv1,
        net.block1,
        net.block2,
        net.block3,
        nn.Sequential(net.bn1, net.relu),
    ]


# Detectors are created in a fixed order; the dict preserves it for the
# fitting loop below.
detectors = {
    "KNN": KNN(model.features),
    "GMM": GMM(model.features),
    "fDBD": fDBD(encoder=model.features, head=model.fc),
    "ASH": ASH(backbone=model.feature_maps, head=model.forward_feature_maps),
    "RankFeat": RankFeat(backbone=model.feature_maps, head=model.forward_feature_maps),
}

# The gradient-norm detectors each get their own copy of the model.
model_gn = _fc_only_copy(model)
detectors["GradNorm"] = GradNorm(model_gn, param_filter=_is_fc_param)

model_gnkl = _fc_only_copy(model)
detectors["GradNormKL"] = GradNormKL(model_gnkl, param_filter=_is_fc_param)

detectors.update(
    {
        "Entropy": Entropy(model),
        "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
        "Mahalanobis+ODIN": MahalanobisODIN(model.features, norm_std=norm_std, eps=0.002),
        "Mahalanobis": Mahalanobis(model.features),
        "KLMatching": KLMatching(model),
        "SHE": SHE(model.features, model.fc),
        "MSP": MaxSoftmax(model),
        "EnergyBased": EnergyBased(model),
        "GEN": GEN(model),
        "MaxLogit": MaxLogit(model),
        "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
        "DICE": DICE(encoder=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
        "RMD": RMD(model.features),
        "MultiMahalanobis": MultiMahalanobis(_layer_stack(model)),
        "Gram": Gram(
            num_classes=10,
            head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
            feature_layers=_layer_stack(model),
        ),
        # hyperparameters determined on Textures dataset
        "NAC-UE": NACUE(
            model=model,
            layers=[model.block2, model.block3, model.bn1],
            m_bins=[200, 200, 200],
            alpha=[150.0, 200.0, 250.0],
            o_star=[25, 50, 100],
            device=device,
        ),
    }
)

# Fit every detector on the CIFAR10 training split. Some detectors need this
# step and some do not; fit() is called on all of them regardless.
print(f"> Fitting {len(detectors)} detectors")
train_set = CIFAR10(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=128, num_workers=12)
for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.to(device)
    detector.fit(loader_in_train)

Stage 3: Evaluate Detectors

print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")

# One row per (detector, dataset) pair, holding the computed metrics.
results = []

# NOTE(review): gradient-based detectors are also called inside no_grad();
# presumably they re-enable gradients internally - confirm.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for inputs, labels in tqdm(loader, desc=dataset_name):
                scores = detector(inputs.to(device))
                metrics.update(scores, labels.to(device))

            row = dict(Detector=detector_name, Dataset=dataset_name)
            row.update(metrics.compute())
            results.append(row)

# Average each metric over all datasets per detector and express it in percent.
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
per_detector = df.groupby("Detector")[metric_columns]
mean_scores = per_detector.mean() * 100

# Print the table as CSV, ordered by AUROC, with two decimal places.
ranked = mean_scores.sort_values("AUROC")
print(ranked.to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery