CIFAR 10

Example benchmark code for CIFAR10

Detector

AUROC

AUTC

AUPR-IN

AUPR-OUT

FPR95TPR

GradNorm

50.00

60.78

18.37

81.63

100.00

Gram

69.37

46.01

58.02

77.49

75.03

KLMatching

88.48

39.83

72.29

91.33

57.84

NAC-UE

88.74

39.89

81.40

90.36

46.12

SHE

90.08

39.69

69.17

92.92

38.48

MSP

91.41

37.07

86.36

92.42

29.93

Entropy

92.03

35.90

86.70

93.47

29.75

Mahalanobis

92.14

42.76

86.39

94.36

28.25

ODIN

92.14

47.06

84.98

94.46

34.43

ViM

92.32

40.22

85.76

94.93

29.49

Mahalanobis+ODIN

92.60

42.76

86.81

95.08

27.11

KNN

92.67

36.61

87.07

94.21

29.49

DICE

92.80

35.83

86.68

94.20

32.35

MultiMahalanobis

92.89

45.35

85.51

96.09

24.95

MaxLogit

93.05

35.84

87.01

94.40

31.31

ASH

93.06

35.70

77.62

94.43

31.31

EnergyBased

93.11

35.45

87.09

94.46

31.14

RMD

93.46

32.09

87.73

95.08

26.99

52 import pandas as pd  # additional dependency, used here for convenience
53 from torch import nn
54 from torch.utils.data import DataLoader
55 from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
56 from copy import deepcopy
57 from tqdm.auto import tqdm  # additional dependency, used here for convenience
58 import torch
59
60 from pytorch_ood.dataset.img import (
61     LSUNCrop,
62     LSUNResize,
63     Textures,
64     TinyImageNetCrop,
65     TinyImageNetResize,
66     Places365,
67 )
68 from pytorch_ood.detector import (
69     ODIN,
70     EnergyBased,
71     Entropy,
72     KLMatching,
73     Mahalanobis,
74     MaxLogit,
75     MaxSoftmax,
76     ViM,
77     RMD,
78     DICE,
79     SHE,
80     Gram,
81     MultiMahalanobis,
82     NACUE,
83     GradNorm,
84     ASH,
85     KNN,
86 )
87 from pytorch_ood.model import WideResNet
88 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
89
# Run on the first CUDA device when one is available; otherwise fall back to
# the CPU so the benchmark still runs (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Seed the random number generators with a fixed value before anything else
# is created.
fix_random_seed(123)

Setup preprocessing

# The input transform and the normalisation statistics are both looked up
# under the same pre-trained weight identifier.
pretrained_key = "cifar10-pt"
trans = WideResNet.transform_for(pretrained_key)
# `norm_std` is handed to the ODIN-style detectors further down.
norm_std = WideResNet.norm_std_for(pretrained_key)

Setup datasets

# In-distribution data: the CIFAR10 test split.
dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)

# Dataset classes used as out-of-distribution data.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]


def _mixed_loader(ood_cls):
    """Return a loader over the CIFAR10 test split concatenated with ``ood_cls``.

    The OOD dataset is created with ``ToUnknown()`` as its target transform,
    and the two datasets are combined with ``+`` before being wrapped in a
    ``DataLoader``.
    """
    ood_part = ood_cls(
        root="data", transform=trans, target_transform=ToUnknown(), download=True
    )
    return DataLoader(dataset_in_test + ood_part, batch_size=256, num_workers=12)


# One test loader per OOD dataset, keyed by the dataset class name.
datasets = {ood_cls.__name__: _mixed_loader(ood_cls) for ood_cls in ood_datasets}

Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper

print("STAGE 1: Creating a Model")
# Build the 10-class WideResNet with the "cifar10-pt" pre-trained weights,
# then switch it to evaluation mode and move it to the target device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)

Stage 2: Create OOD detector

print("STAGE 2: Creating OOD Detectors")


def _feature_stages():
    """Build a new list of the five model stages given to the multi-layer detectors.

    A fresh list (with its own ``nn.Sequential`` wrapper around ``bn1`` and
    ``relu``) is returned on every call, so no two detectors share the same
    list object.
    """
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


# Detectors are keyed by display name; insertion order determines the order
# in which they are fitted and evaluated.
detectors = {
    "KNN": KNN(model.features),
    "ASH": ASH(backbone=model.features_before_pool, head=model.forward_from_before_pool),
}

# GradNorm works on its own copy of the model, so that switching gradients
# off (everywhere except the final ``fc`` layer) does not affect the model
# shared by the other detectors.
model_gn = deepcopy(model)
model_gn.requires_grad_(False)
model_gn.fc.requires_grad_(True)
detectors["GradNorm"] = GradNorm(model_gn, param_filter=lambda name: name.startswith("fc"))

detectors.update(
    {
        "Entropy": Entropy(model),
        "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
        "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
        "Mahalanobis": Mahalanobis(model.features),
        "KLMatching": KLMatching(model),
        "SHE": SHE(model.features, model.fc),
        "MSP": MaxSoftmax(model),
        "EnergyBased": EnergyBased(model),
        "MaxLogit": MaxLogit(model),
        "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
        "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
        "RMD": RMD(model.features),
        "MultiMahalanobis": MultiMahalanobis(_feature_stages()),
        "Gram": Gram(
            num_classes=10,
            head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
            feature_layers=_feature_stages(),
        ),
        # hyperparameters determined on Textures dataset
        "NAC-UE": NACUE(
            model=model,
            layers=[model.block2, model.block3, model.bn1],
            m_bins=[200, 200, 200],
            alpha=[150.0, 200.0, 250.0],
            o_star=[25, 50, 100],
            device=device,
        ),
    }
)

# fit detectors to training data (some require this, some do not)
print(f"> Fitting {len(detectors)} detectors")
train_set = CIFAR10(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=128, num_workers=12)
for label, ood_detector in detectors.items():
    print(f"--> Fitting {label}")
    ood_detector.fit(loader_in_train, device=device)

Stage 3: Evaluate Detectors

print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# One row per (detector, dataset) pair: the two names plus every metric
# returned by ``OODMetrics.compute()``.
results = []

# NOTE(review): ODIN and GradNorm are gradient-based but are called under
# ``torch.no_grad()`` here - presumably they re-enable gradients internally;
# confirm against the detector implementations.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            # fresh metric accumulator for every (detector, dataset) pair
            metrics = OODMetrics()
            for x, y in tqdm(loader, desc=dataset_name):
                metrics.update(detector(x.to(device)), y.to(device))

            # Record the result inside the dataset loop. One level further
            # out, only the metrics of the last dataset would be stored for
            # each detector, and the "mean over all datasets" below would be
            # computed from a single dataset.
            r = {"Detector": detector_name, "Dataset": dataset_name}
            r.update(metrics.compute())
            results.append(r)

# calculate mean scores over all datasets, use percent
df = pd.DataFrame(results)
mean_scores = (
    df.groupby("Detector")[["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]].mean() * 100
)
# print the per-detector means as CSV, sorted by ascending AUROC
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery