Note
Go to the end to download the full example code.
CIFAR 100
The evaluation is the same as for CIFAR 10.
| Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|
| SHE | 59.43 | 77.44 | 68.37 | 100.00 |
| Mahalanobis | 75.35 | 81.59 | 65.62 | 58.87 |
| MSP | 78.78 | 82.37 | 71.34 | 57.67 |
| Mahalanobis+ODIN | 79.24 | 84.58 | 68.69 | 55.91 |
| KLMatching | 79.88 | 83.53 | 68.23 | 60.02 |
| ODIN | 80.80 | 83.96 | 73.40 | 54.92 |
| Entropy | 81.19 | 84.61 | 73.08 | 56.49 |
| ViM | 81.73 | 85.87 | 72.91 | 49.86 |
| RMD | 83.23 | 86.94 | 74.56 | 50.55 |
| MaxLogit | 84.70 | 86.66 | 78.33 | 47.40 |
| EnergyBased | 85.00 | 86.88 | 78.69 | 46.70 |
| DICE | 85.35 | 87.32 | 78.99 | 46.17 |
37 import pandas as pd # additional dependency, used here for convenience
38 import torch
39 from torch.utils.data import DataLoader
40 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST
41
42 from pytorch_ood.dataset.img import (
43 LSUNCrop,
44 LSUNResize,
45 Textures,
46 TinyImageNetCrop,
47 TinyImageNetResize,
48 Places365,
49 TinyImageNet,
50 )
51 from pytorch_ood.detector import (
52 ODIN,
53 EnergyBased,
54 Entropy,
55 KLMatching,
56 Mahalanobis,
57 MaxLogit,
58 MaxSoftmax,
59 ViM,
60 RMD,
61 DICE,
62 SHE,
63 )
64 from pytorch_ood.model import WideResNet
65 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
66
# Prefer the first CUDA device, but fall back to the CPU so that the script
# still runs (slowly) on machines without a GPU instead of crashing.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed so that repeated runs are comparable.
fix_random_seed(123)

# setup preprocessing: input transform and normalization std that belong to
# the pre-trained "cifar100-pt" WideResNet weights
trans = WideResNet.transform_for("cifar100-pt")
norm_std = WideResNet.norm_std_for("cifar100-pt")
Setup datasets
# In-distribution data: the CIFAR-100 test split.
dataset_in_test = CIFAR100(root="data", train=False, transform=trans, download=True)

# Dataset classes that serve as out-of-distribution (OOD) data.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR10,
    MNIST,
    FashionMNIST,
]

# One loader per OOD dataset, keyed by the dataset class name. Each loader
# iterates over the in-distribution test set followed by the OOD samples.
datasets = {}
for ood_cls in ood_datasets:
    # ToUnknown() is applied to the targets of the OOD samples
    # (presumably relabeling them as "unknown" -- confirm in pytorch_ood docs).
    ood_test_set = ood_cls(
        root="data",
        transform=trans,
        target_transform=ToUnknown(),
        download=True,
    )
    combined_test_set = dataset_in_test + ood_test_set
    datasets[ood_cls.__name__] = DataLoader(
        combined_test_set,
        batch_size=256,
        num_workers=12,
    )
Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper
print("STAGE 1: Creating a Model")
# 100-class WideResNet with the "cifar100-pt" pre-trained weights, in eval mode.
model = WideResNet(num_classes=100, pretrained="cifar100-pt").eval().to(device)

# Stage 2: Create OOD detectors
print("STAGE 2: Creating OOD Detectors")
# Insertion order is kept identical to the order in which the detectors are
# later fitted and evaluated. Some detectors wrap the full model, others work
# on the feature extractor and receive the weight/bias of the final layer.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
}
Stage 2: fit detectors to training data (some require this, some do not)
print(f"> Fitting {len(detectors)} detectors")

# All detectors are fitted on the CIFAR-100 training split; fit() is called
# uniformly on every detector, whether or not it actually needs fitting.
dataset_in_train = CIFAR100(root="data", train=True, transform=trans)
loader_in_train = DataLoader(dataset_in_train, batch_size=256, num_workers=12)

for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.fit(loader_in_train, device=device)
Stage 3: Evaluate Detectors
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")

# One row per (detector, dataset) pair, holding the computed metrics.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")

        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")

            # Fresh metric accumulator for every detector/dataset combination.
            metrics = OODMetrics()
            for inputs, labels in loader:
                scores = detector(inputs.to(device))
                metrics.update(scores, labels.to(device))

            # Metric entries are merged after the identifying columns.
            results.append(
                {
                    "Detector": detector_name,
                    "Dataset": dataset_name,
                    **metrics.compute(),
                }
            )
153
# calculate mean scores over all datasets, use percent
df = pd.DataFrame(results)
# "Dataset" is a string column. Since pandas 2.0, GroupBy.mean() no longer
# drops non-numeric columns silently and raises a TypeError instead, so the
# aggregation is explicitly restricted to the numeric metric columns.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
# Print as CSV, sorted ascending by AUROC, with two decimal places.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))