Note: Go to the end to download the full example code.
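CIFAR-100
The evaluation follows the same protocol as for CIFAR-10. The table below lists, for each detector, the mean of each metric over all OOD datasets (in percent), sorted by ascending AUROC. GEN, GMM, fDBD and RankFeat are also evaluated by the script below but do not appear in the table.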
| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| Gram | 48.29 | 50.66 | 38.24 | 63.76 | 91.97 |
| SHE | 59.43 | 43.67 | 68.37 | 77.44 | 100.00 |
| Mahalanobis | 75.35 | 45.59 | 65.62 | 81.59 | 58.87 |
| MSP | 78.78 | 37.32 | 71.34 | 82.37 | 57.67 |
| Mahalanobis+ODIN | 79.24 | 44.89 | 68.69 | 84.58 | 55.91 |
| KLMatching | 79.88 | 41.07 | 68.23 | 83.53 | 60.02 |
| ODIN | 80.80 | 44.90 | 73.40 | 83.96 | 54.92 |
| Entropy | 81.19 | 38.44 | 73.08 | 84.61 | 56.49 |
| ViM | 81.73 | 43.50 | 72.91 | 85.87 | 49.86 |
| RMD | 83.23 | 39.43 | 74.56 | 86.94 | 50.55 |
| MaxLogit | 84.70 | 41.89 | 78.33 | 86.66 | 47.40 |
| EnergyBased | 85.00 | 41.89 | 78.69 | 86.88 | 46.70 |
| MultiMahalanobis | 85.33 | 45.93 | 77.84 | 89.51 | 39.25 |
| DICE | 85.35 | 41.84 | 78.99 | 87.32 | 46.17 |
43 import pandas as pd # additional dependency, used here for convenience
44 import torch
45 from torch.utils.data import DataLoader
46 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST
47 from torch import nn
48
49 from pytorch_ood.dataset.img import (
50 LSUNCrop,
51 LSUNResize,
52 Textures,
53 TinyImageNetCrop,
54 TinyImageNetResize,
55 Places365,
56 )
57 from pytorch_ood.detector import (
58 ODIN,
59 EnergyBased,
60 Entropy,
61 GEN,
62 KLMatching,
63 Mahalanobis,
64 MaxLogit,
65 MaxSoftmax,
66 ViM,
67 RMD,
68 DICE,
69 SHE,
70 Gram,
71 GMM,
72 MultiMahalanobis,
73 RankFeat,
74 fDBD,
75 )
76 from pytorch_ood.model import WideResNet
77 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
78
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the example still runs (slowly) instead of failing on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed so repeated runs are reproducible.
fix_random_seed(123)

# Preprocessing matching the pre-trained CIFAR-100 WideResNet weights. The
# normalization std is passed to the detectors below that take an `eps`
# perturbation argument (ODIN and Mahalanobis+ODIN).
trans = WideResNet.transform_for("cifar100-pt")
norm_std = WideResNet.norm_std_for("cifar100-pt")
Set up datasets
# In-distribution evaluation data: the CIFAR-100 test split.
dataset_in_test = CIFAR100(root="data", train=False, transform=trans, download=True)

# Every OOD dataset is paired with the CIFAR-100 test set; ToUnknown() relabels
# its samples as unknown.
# NOTE(review): torchvision's CIFAR10, MNIST and FashionMNIST default to
# train=True, so their *training* splits are used as OOD data here. Confirm
# this is intended (it also makes the OUT side much larger than the IN side).
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR10,
    MNIST,
    FashionMNIST,
]


def _make_loader(ood_cls):
    """Build a test loader over the CIFAR-100 test set plus one OOD dataset."""
    ood_part = ood_cls(
        root="data",
        transform=trans,
        target_transform=ToUnknown(),
        download=True,
    )
    return DataLoader(dataset_in_test + ood_part, batch_size=256, num_workers=12)


# Keyed by the OOD dataset class name, in the order listed above.
datasets = {ood_cls.__name__: _make_loader(ood_cls) for ood_cls in ood_datasets}
Stage 1: Create a DNN with pre-trained weights from the Hendrycks baseline paper
print("STAGE 1: Creating a Model")
model = WideResNet(num_classes=100, pretrained="cifar100-pt").eval().to(device)

# Stage 2: Create OOD detector
print("STAGE 2: Creating OOD Detectors")


def _feature_layers():
    """Return a fresh list of the WideResNet stages whose outputs feed the
    multi-layer detectors (MultiMahalanobis and Gram)."""
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


# Values are evaluated top to bottom, so detectors are constructed (and later
# fitted/evaluated) in exactly this order.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "GEN": GEN(model),
    "GMM": GMM(model.features),
    "fDBD": fDBD(encoder=model.features, head=model.fc),
    "RankFeat": RankFeat(
        backbone=model.features_before_pool,
        head=model.forward_from_before_pool,
    ),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
    "MultiMahalanobis": MultiMahalanobis(_feature_layers()),
    "Gram": Gram(
        num_classes=100,
        # Pool + flatten + classifier turn the last feature map into logits.
        head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
        feature_layers=_feature_layers(),
    ),
}
Stage 2 (continued): Fit the detectors to the training data (some detectors require this, others do not)
print(f"> Fitting {len(detectors)} detectors")

# CIFAR-100 training split used to fit the detectors. No download flag here;
# presumably relies on the files fetched for the test split above — confirm.
train_set = CIFAR100(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=256, num_workers=12)

for det_name, det in detectors.items():
    print(f"--> Fitting {det_name}")
    det.to(device)
    det.fit(loader_in_train)
Stage 3: Evaluate Detectors
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")


def _score(detector, loader):
    """Run *detector* over every batch of *loader* and return its OOD metrics dict."""
    metrics = OODMetrics()
    for x, y in loader:
        metrics.update(detector(x.to(device)), y.to(device))
    return metrics.compute()


results = []
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            # One row per (detector, dataset) pair; the metric keys follow the names.
            row = {"Detector": detector_name, "Dataset": dataset_name, **_score(detector, loader)}
            results.append(row)

# Average each metric over all OOD datasets per detector and express it in percent.
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean().mul(100)
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))