Note
Go to the end to download the full example code.
CIFAR 100
The evaluation is the same as for CIFAR 10.
| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| Gram | 48.29 | 50.66 | 38.24 | 63.76 | 91.97 |
| SHE | 59.43 | 43.67 | 68.37 | 77.44 | 100.00 |
| Mahalanobis | 75.35 | 45.59 | 65.62 | 81.59 | 58.87 |
| MSP | 78.78 | 37.32 | 71.34 | 82.37 | 57.67 |
| Mahalanobis+ODIN | 79.24 | 44.89 | 68.69 | 84.58 | 55.91 |
| KLMatching | 79.88 | 41.07 | 68.23 | 83.53 | 60.02 |
| ODIN | 80.80 | 44.90 | 73.40 | 83.96 | 54.92 |
| Entropy | 81.19 | 38.44 | 73.08 | 84.61 | 56.49 |
| ViM | 81.73 | 43.50 | 72.91 | 85.87 | 49.86 |
| RMD | 83.23 | 39.43 | 74.56 | 86.94 | 50.55 |
| MaxLogit | 84.70 | 41.89 | 78.33 | 86.66 | 47.40 |
| EnergyBased | 85.00 | 41.89 | 78.69 | 86.88 | 46.70 |
| MultiMahalanobis | 85.33 | 45.93 | 77.84 | 89.51 | 39.25 |
| DICE | 85.35 | 41.84 | 78.99 | 87.32 | 46.17 |
43 import pandas as pd # additional dependency, used here for convenience
44 import torch
45 from torch.utils.data import DataLoader
46 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST
47 from torch import nn
48
49 from pytorch_ood.dataset.img import (
50 LSUNCrop,
51 LSUNResize,
52 Textures,
53 TinyImageNetCrop,
54 TinyImageNetResize,
55 Places365,
56 )
57 from pytorch_ood.detector import (
58 ODIN,
59 EnergyBased,
60 Entropy,
61 KLMatching,
62 Mahalanobis,
63 MaxLogit,
64 MaxSoftmax,
65 ViM,
66 RMD,
67 DICE,
68 SHE,
69 Gram,
70 MultiMahalanobis,
71 )
72 from pytorch_ood.model import WideResNet
73 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
74
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the script does not fail on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

fix_random_seed(123)

# setup preprocessing
# Both values are looked up under the "cifar100-pt" key, the same key that is
# used further down to load the pre-trained WideResNet weights.
trans = WideResNet.transform_for("cifar100-pt")
norm_std = WideResNet.norm_std_for("cifar100-pt")
Set up datasets
# In-distribution data: the CIFAR-100 test split.
dataset_in_test = CIFAR100(root="data", train=False, transform=trans, download=True)

# create all OOD datasets
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR10,
    MNIST,
    FashionMNIST,
]

# One loader per OOD dataset, keyed by the dataset class name. Each loader
# iterates over the CIFAR-100 test split followed by the OOD dataset.
datasets = {}
for ood_dataset in ood_datasets:
    # ToUnknown() is installed as the target transform of every OOD dataset.
    dataset_out_test = ood_dataset(
        root="data",
        transform=trans,
        target_transform=ToUnknown(),
        download=True,
    )
    combined_test_set = dataset_in_test + dataset_out_test
    test_loader = DataLoader(combined_test_set, batch_size=256, num_workers=12)
    datasets[ood_dataset.__name__] = test_loader
Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper
print("STAGE 1: Creating a Model")
# WideResNet with 100 output classes and the "cifar100-pt" pre-trained
# weights, switched to eval mode and moved to the selected device.
model = WideResNet(num_classes=100, pretrained="cifar100-pt").eval().to(device)

# Stage 2: Create OOD detector
print("STAGE 2: Creating OOD Detectors")
# All detectors share the same model. Some wrap the full model, others wrap
# model.features and receive the final layer (model.fc) or its weight and bias
# separately. The insertion order below is the order in which the detectors
# are fitted and evaluated later on.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
    # Multi-layer detector: receives the network as a list of consecutive stages.
    "MultiMahalanobis": MultiMahalanobis(
        [
            model.conv1,
            model.block1,
            model.block2,
            model.block3,
            nn.Sequential(model.bn1, model.relu),
        ]
    ),
    # Gram receives the same stages as feature layers, plus a head made of
    # global average pooling, flattening and the model's final layer.
    "Gram": Gram(
        num_classes=100,
        head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
        feature_layers=[
            model.conv1,
            model.block1,
            model.block2,
            model.block3,
            nn.Sequential(model.bn1, model.relu),
        ],
    ),
}
Stage 2: Fit the detectors to the training data (some detectors require this, others do not)
print(f"> Fitting {len(detectors)} detectors")

# CIFAR-100 training split, preprocessed with the same transform as the test data.
train_set = CIFAR100(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=256, num_workers=12)

# fit() is called on every detector with the same training loader.
for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.fit(loader_in_train, device=device)
Stage 3: Evaluate Detectors
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []

# The whole evaluation sweep runs inside torch.no_grad().
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            # A fresh metrics accumulator for every (detector, dataset) pair.
            metrics = OODMetrics()
            for x, y in loader:
                scores = detector(x.to(device))
                metrics.update(scores, y.to(device))

            # One result row: identifying columns followed by the computed metrics.
            r = {"Detector": detector_name, "Dataset": dataset_name, **metrics.compute()}
            results.append(r)

# calculate mean scores over all datasets, use percent
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean() * 100
# Print as CSV, sorted by ascending AUROC, with two decimal places.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))