Note
Go to the end to download the full example code.
CIFAR 10
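Example benchmark code for CIFAR10. The table below lists, for each detector, the mean of every metric (in percent) over all OOD datasets, sorted by AUROC.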
| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| Gram | 69.37 | 46.01 | 58.02 | 77.49 | 75.03 |
| KLMatching | 88.48 | 39.83 | 72.29 | 91.33 | 57.84 |
| SHE | 90.08 | 39.69 | 69.17 | 92.92 | 38.48 |
| MSP | 91.41 | 37.07 | 86.36 | 92.42 | 29.93 |
| Entropy | 92.03 | 35.90 | 86.70 | 93.47 | 29.75 |
| Mahalanobis | 92.14 | 42.76 | 86.39 | 94.36 | 28.25 |
| ODIN | 92.14 | 47.06 | 84.98 | 94.46 | 34.43 |
| ViM | 92.32 | 40.22 | 85.76 | 94.93 | 29.49 |
| Mahalanobis+ODIN | 92.60 | 42.76 | 86.81 | 95.08 | 27.11 |
| DICE | 92.80 | 35.83 | 86.68 | 94.20 | 32.35 |
| MaxLogit | 93.05 | 35.84 | 87.01 | 94.40 | 31.31 |
| EnergyBased | 93.11 | 35.45 | 87.09 | 94.46 | 31.14 |
| MultiMahalanobis | 93.43 | 44.60 | 86.70 | 96.48 | 22.95 |
| RMD | 93.46 | 32.09 | 87.73 | 95.08 | 26.99 |
44 import pandas as pd # additional dependency, used here for convenience
45 import torch
46 from torch import nn
47 from torch.utils.data import DataLoader
48 from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
49
50 from pytorch_ood.dataset.img import (
51 LSUNCrop,
52 LSUNResize,
53 Textures,
54 TinyImageNetCrop,
55 TinyImageNetResize,
56 Places365,
57 )
58 from pytorch_ood.detector import (
59 ODIN,
60 EnergyBased,
61 Entropy,
62 KLMatching,
63 Mahalanobis,
64 MaxLogit,
65 MaxSoftmax,
66 ViM,
67 RMD,
68 DICE,
69 SHE,
70 Gram,
71 MultiMahalanobis,
72 )
73 from pytorch_ood.model import WideResNet
74 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
75
# Run on the first GPU when CUDA is available; otherwise fall back to the CPU
# so the example still executes (slowly) instead of failing on `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix random seeds (pytorch_ood helper) so repeated runs are reproducible.
fix_random_seed(123)
Setup preprocessing
# Both values are tied to the "cifar10-pt" weights loaded in stage 1.
# `norm_std` is handed to the detectors configured with an `eps`
# (ODIN and Mahalanobis+ODIN); `trans` is applied to every dataset below.
norm_std = WideResNet.norm_std_for("cifar10-pt")
trans = WideResNet.transform_for("cifar10-pt")
Setup datasets
# In-distribution data: the CIFAR10 test split, shared by every benchmark pair.
dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)

# OOD datasets, in the order they are evaluated.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]

# One loader per OOD dataset, keyed by its class name, iterating over the
# concatenation of the CIFAR10 test set and that OOD set. ToUnknown() is the
# OOD target transform, presumably so OODMetrics can tell OOD from ID samples.
# NOTE(review): CIFAR100, MNIST and FashionMNIST are torchvision classes built
# here without `train=`, so they presumably use their default (training) split
# rather than the test split -- confirm this is intended.
datasets = {}
for ood_cls in ood_datasets:
    ood_part = ood_cls(
        root="data", transform=trans, target_transform=ToUnknown(), download=True
    )
    datasets[ood_cls.__name__] = DataLoader(
        dataset_in_test + ood_part, batch_size=512, num_workers=12
    )
Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper
print("STAGE 1: Creating a Model")
# WideResNet over the 10 CIFAR10 classes with the "cifar10-pt" weights,
# put into eval mode and moved to `device`.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)
Stage 2: Create OOD detector
def _wrn_feature_layers():
    """Return a fresh list of WideResNet stages for the multi-layer detectors.

    The list runs from the first convolution to the final BN+ReLU. A new
    ``nn.Sequential(bn1, relu)`` wrapper is built on every call, so each
    detector receives its own list.
    """
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


print("STAGE 2: Creating OOD Detectors")
# Insertion order matters: it fixes the order in which detectors are fitted
# and evaluated below.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    # Mahalanobis with an input perturbation (eps), analogous to ODIN.
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
    "MultiMahalanobis": MultiMahalanobis(_wrn_feature_layers()),
    # Gram needs a head that maps the last feature map to class logits.
    "Gram": Gram(
        num_classes=10,
        head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
        feature_layers=_wrn_feature_layers(),
    ),
}
# Call fit() on every detector with the CIFAR10 training split. Not every
# detector needs training data, but all are fitted uniformly. No download
# flag here: the archive is presumably already in "data" from the test split.
print(f"> Fitting {len(detectors)} detectors")
train_set = CIFAR10(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=512, num_workers=12)
for detector_name, det in detectors.items():
    print(f"--> Fitting {detector_name}")
    det.fit(loader_in_train, device=device)
Stage 3: Evaluate Detectors
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []


def _score(detector, loader):
    """Run ``detector`` over every batch of ``loader``; return OODMetrics' summary dict."""
    metrics = OODMetrics()
    for inputs, labels in loader:
        metrics.update(detector(inputs.to(device)), labels.to(device))
    return metrics.compute()


# Autograd is disabled for the whole sweep; one result row is collected per
# (detector, dataset) pair.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            row = {"Detector": detector_name, "Dataset": dataset_name}
            row.update(_score(detector, loader))
            results.append(row)
# Collapse the per-dataset rows into one row per detector: each metric is
# averaged over all OOD datasets and scaled to percent.
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean().mul(100)
# Ascending AUROC, so the highest-AUROC detector is printed last.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))
187 print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))