Note
Go to the end to download the full example code.
CIFAR 10
Example benchmark code for CIFAR10
| Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|
| KLMatching | 88.73 | 86.95 | 85.10 | 58.73 |
| SHE | 90.67 | 89.63 | 77.72 | 37.41 |
| MSP | 91.85 | 88.55 | 93.57 | 28.43 |
| Entropy | 92.48 | 90.16 | 93.87 | 28.29 |
| DICE | 92.63 | 91.07 | 93.30 | 32.78 |
| MaxLogit | 93.06 | 91.44 | 93.74 | 31.18 |
| EnergyBased | 93.10 | 91.51 | 93.78 | 31.05 |
| ODIN | 93.20 | 92.12 | 93.94 | 31.65 |
| RMD | 94.03 | 92.73 | 94.65 | 25.42 |
| Mahalanobis | 94.06 | 92.42 | 95.17 | 22.79 |
| ViM | 94.49 | 93.42 | 95.34 | 23.48 |
| Mahalanobis+ODIN | 94.87 | 93.69 | 95.79 | 21.05 |
38 import pandas as pd # additional dependency, used here for convenience
39 import torch
40 from torch.utils.data import DataLoader
41 from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
42
43 from pytorch_ood.dataset.img import (
44 LSUNCrop,
45 LSUNResize,
46 Textures,
47 TinyImageNetCrop,
48 TinyImageNetResize,
49 Places365,
50 )
51 from pytorch_ood.detector import (
52 ODIN,
53 EnergyBased,
54 Entropy,
55 KLMatching,
56 Mahalanobis,
57 MaxLogit,
58 MaxSoftmax,
59 ViM,
60 RMD,
61 DICE,
62 SHE,
63 )
64 from pytorch_ood.model import WideResNet
65 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
66
# Run on the first CUDA device when one is available; otherwise fall back to
# the CPU so the benchmark can still be executed on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed through pytorch_ood's helper before any data or model
# is created (presumably for reproducibility of the run -- see the helper's
# documentation for which generators it seeds).
fix_random_seed(123)
Setup preprocessing
# Look up the input transform and the `norm_std` value that belong to the
# "cifar10-pt" pre-trained weights. `trans` is applied to every dataset below;
# `norm_std` is passed to the detectors that take a `norm_std` argument.
trans, norm_std = (
    WideResNet.transform_for("cifar10-pt"),
    WideResNet.norm_std_for("cifar10-pt"),
)
Setup datasets
# In-distribution data: the CIFAR10 split selected by train=False.
dataset_in_test = CIFAR10(
    root="data",
    train=False,
    transform=trans,
    download=True,
)

# Dataset classes that serve as OOD data; each one is instantiated below.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR100,
    MNIST,
    FashionMNIST,
]

# Maps the OOD dataset class name to a loader over "CIFAR10 test + OOD data".
datasets = {}
for ood_dataset in ood_datasets:
    # ToUnknown() is applied to the targets of the OOD samples
    # (presumably mapping them to an "unknown" label -- confirm in pytorch_ood).
    dataset_out_test = ood_dataset(
        root="data",
        transform=trans,
        target_transform=ToUnknown(),
        download=True,
    )
    # `+` concatenates the in-distribution and the OOD dataset.
    mixed_test_set = dataset_in_test + dataset_out_test
    test_loader = DataLoader(mixed_test_set, batch_size=256, num_workers=12)
    datasets[ood_dataset.__name__] = test_loader
Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper
print("STAGE 1: Creating a Model")
# Build the 10-class WideResNet with the "cifar10-pt" pre-trained weights,
# switch it to evaluation mode and move it to the selected device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval()
model = model.to(device)
Stage 2: Create OOD detectors
print("STAGE 2: Creating OOD Detectors")
# Detector name -> detector instance. The insertion order is significant: it
# is the order in which the detectors are fitted and later evaluated.
# Some detectors wrap the full model, others only the feature extractor
# (`model.features`), together with the final layer `model.fc` or its
# weight/bias where the constructor asks for them.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
}

# Fit every detector on the CIFAR10 training split (some detectors require
# fitting, some do not; `fit` is called on all of them uniformly).
print(f"> Fitting {len(detectors)} detectors")
train_set = CIFAR10(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=256, num_workers=12)
for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.fit(loader_in_train, device=device)
Stage 3: Evaluate Detectors
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# One row per (detector, dataset) pair, holding the computed metrics.
results = []

# Evaluation runs inside torch.no_grad().
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            # Fresh metric accumulator for every detector/dataset combination.
            metrics = OODMetrics()
            for x, y in loader:
                scores = detector(x.to(device))
                metrics.update(scores, y.to(device))

            # Identify the row, then merge in whatever metrics were computed.
            row = dict(Detector=detector_name, Dataset=dataset_name)
            row.update(metrics.compute())
            results.append(row)
155
# Calculate the mean score of every metric over all datasets, per detector,
# and express the values in percent.
df = pd.DataFrame(results)
# "Dataset" is a string column. Restrict the aggregation to numeric columns:
# pandas >= 2.0 no longer drops non-numeric columns silently and would raise
# a TypeError when asked to average strings.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
# Print as CSV, sorted by ascending AUROC, with two decimal places.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))