Note
Go to the end to download the full example code.
CIFAR 100
The evaluation is the same as for CIFAR 10.
| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| Gram | 48.29 | 50.66 | 38.24 | 63.76 | 91.97 |
| SHE | 59.43 | 43.67 | 68.37 | 77.44 | 100.00 |
| Mahalanobis | 75.35 | 45.59 | 65.62 | 81.59 | 58.87 |
| MSP | 78.78 | 37.32 | 71.34 | 82.37 | 57.67 |
| Mahalanobis+ODIN | 79.24 | 44.89 | 68.69 | 84.58 | 55.91 |
| KLMatching | 79.88 | 41.07 | 68.23 | 83.53 | 60.02 |
| ODIN | 80.80 | 44.90 | 73.40 | 83.96 | 54.92 |
| Entropy | 81.19 | 38.44 | 73.08 | 84.61 | 56.49 |
| ViM | 81.73 | 43.50 | 72.91 | 85.87 | 49.86 |
| RMD | 83.23 | 39.43 | 74.56 | 86.94 | 50.55 |
| MaxLogit | 84.70 | 41.89 | 78.33 | 86.66 | 47.40 |
| EnergyBased | 85.00 | 41.89 | 78.69 | 86.88 | 46.70 |
| MultiMahalanobis | 85.33 | 45.93 | 77.84 | 89.51 | 39.25 |
| DICE | 85.35 | 41.84 | 78.99 | 87.32 | 46.17 |
43 import pandas as pd # additional dependency, used here for convenience
44 import torch
45 from torch.utils.data import DataLoader
46 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST
47 from torch import nn
48
49 from pytorch_ood.dataset.img import (
50 LSUNCrop,
51 LSUNResize,
52 Textures,
53 TinyImageNetCrop,
54 TinyImageNetResize,
55 Places365,
56 )
57 from pytorch_ood.detector import (
58 ODIN,
59 EnergyBased,
60 Entropy,
61 GEN,
62 KLMatching,
63 Mahalanobis,
64 MahalanobisODIN,
65 MaxLogit,
66 MaxSoftmax,
67 ViM,
68 RMD,
69 DICE,
70 SHE,
71 Gram,
72 GMM,
73 MultiMahalanobis,
74 RankFeat,
75 fDBD,
76 )
77 from pytorch_ood.model import WideResNet
78 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
79
# Run on the first CUDA device when one is available; otherwise fall back to
# the CPU so the example does not crash on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed (123) before any dataset, model or detector is created.
fix_random_seed(123)

# setup preprocessing
# Both values are looked up under the same "cifar100-pt" key that is used for
# the pre-trained weights; `norm_std` is later handed to the ODIN-based
# detectors through their `norm_std` argument.
trans = WideResNet.transform_for("cifar100-pt")
norm_std = WideResNet.norm_std_for("cifar100-pt")
Setup datasets
# In-distribution data: the CIFAR-100 test split.
dataset_in_test = CIFAR100(root="data", train=False, transform=trans, download=True)

# create all OOD datasets
# Only the dataset classes are listed here; they are instantiated one by one
# in the loop below.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR10,
    MNIST,
    FashionMNIST,
]

# Maps the OOD dataset's class name to a loader over the concatenation of the
# in-distribution test set and that OOD dataset.
datasets = {}
for ood_dataset in ood_datasets:
    # Targets of the OOD samples are passed through ToUnknown().
    dataset_out_test = ood_dataset(
        root="data",
        transform=trans,
        target_transform=ToUnknown(),
        download=True,
    )
    combined = dataset_in_test + dataset_out_test
    test_loader = DataLoader(combined, batch_size=256, num_workers=12)
    datasets[ood_dataset.__name__] = test_loader
Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper
print("STAGE 1: Creating a Model")
model = WideResNet(num_classes=100, pretrained="cifar100-pt")
model = model.eval().to(device)


def _feature_layers():
    """Return a fresh list of the layers handed to the multi-layer detectors.

    A new list (and a new ``nn.Sequential`` wrapping ``bn1`` and ``relu``) is
    built on every call, so each detector receives its own objects.
    """
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


# Stage 2: Create OOD detector
print("STAGE 2: Creating OOD Detectors")
# The insertion order below is the order in which the detectors are later
# fitted and evaluated.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis+ODIN": MahalanobisODIN(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "GEN": GEN(model),
    "GMM": GMM(model.features),
    "fDBD": fDBD(encoder=model.features, head=model.fc),
    "RankFeat": RankFeat(backbone=model.feature_maps, head=model.forward_feature_maps),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(encoder=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
    "MultiMahalanobis": MultiMahalanobis(_feature_layers()),
    "Gram": Gram(
        num_classes=100,
        # Head applied on top of the last feature layer: pool, flatten, classify.
        head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
        feature_layers=_feature_layers(),
    ),
}
Stage 2 (continued): Fit the detectors on the in-distribution training data (some detectors require this, some do not)
print(f"> Fitting {len(detectors)} detectors")

# In-distribution training split of CIFAR-100, preprocessed like the test data.
train_set = CIFAR100(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=256, num_workers=12)

# Every detector is moved to the target device and then fitted on the same
# training loader, in the order in which it was registered.
for name in detectors:
    print(f"--> Fitting {name}")
    detector = detectors[name]
    detector.to(device)
    detector.fit(loader_in_train)
Stage 3: Evaluate Detectors
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")

# One row per (detector, dataset) pair: the two identifying columns plus
# whatever OODMetrics.compute() returns.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for inputs, labels in loader:
                scores = detector(inputs.to(device))
                metrics.update(scores, labels.to(device))

            results.append(
                {"Detector": detector_name, "Dataset": dataset_name, **metrics.compute()}
            )

# calculate mean scores over all datasets, use percent
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean() * 100

# Emit the table as CSV, ordered by ascending AUROC, with two decimals.
ranked = mean_scores.sort_values("AUROC")
print(ranked.to_csv(float_format="%.2f"))