Note
Go to the end to download the full example code.
OpenOOD v1.5 - CIFAR10, Many Detectors
Evaluates a broad set of image-classification detectors on the CIFAR-10 OpenOOD benchmark using the benchmark interface and cached intermediate representations.
This example focuses on detectors that can run directly on the pretrained WideResNet used throughout the repository. It omits methods that require extra external dependencies or method-specific trained weights, such as OpenMax and WeightedEBO. MCD appears only in commented-out lines of the code below and is therefore not evaluated either.
14 from collections import OrderedDict
15 from copy import deepcopy
16
17 import pandas as pd # additional dependency, used here for convenience
18 import torch
19 from torch import nn
20 from torch.utils.data import DataLoader, Subset
21
22 from pytorch_ood.benchmark import CIFAR10_OpenOOD
23 from pytorch_ood.detector import (
24 ASH,
25 DICE,
26 EnergyBased,
27 Entropy,
28 GEN,
29 GMM,
30 GradNorm,
31 GradNormKL,
32 Gram,
33 KLMatching,
34 KNN,
35 Mahalanobis,
36 MahalanobisODIN,
37 MaxLogit,
38 MaxSoftmax,
39 # MCD,
40 MultiMahalanobis,
41 NACUE,
42 NCI,
43 NNGuide,
44 ODIN,
45 PNML,
46 RMD,
47 RankFeat,
48 ReAct,
49 SHE,
50 TemperatureScaling,
51 ViM,
52 VRA,
53 fDBD,
54 )
55 from pytorch_ood.model import WideResNet
56 from pytorch_ood.utils import fix_random_seed
57
# Fix the random seed through the library helper before anything else runs.
fix_random_seed(123)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Keyword arguments shared by every DataLoader created in this script and
# also handed to ``benchmark.evaluate``.
loader_kwargs = dict(batch_size=128, num_workers=12)

# Cache location and key passed to ``benchmark.evaluate``.
cache_dir = "data/benchmark-cache"
cache_key = "cifar10-openood-wrn-cifar10-pt"

# Value forwarded to ReAct as its ``threshold`` argument.
react_threshold = 1.0
65
66
def build_detectors(model, norm_std, react_threshold):
    """Create every detector compared in this example, keyed by display name.

    Args:
        model: Network providing the attributes read below: ``features``,
            ``fc``, ``feature_maps``, ``forward_feature_maps``, ``conv1``,
            ``block1`` to ``block3``, ``bn1`` and ``relu``.
        norm_std: Forwarded as ``norm_std`` to ODIN and MahalanobisODIN.
        react_threshold: Forwarded as ``threshold`` to ReAct.

    Returns:
        An ``OrderedDict`` mapping detector name to detector instance, in the
        order in which the detectors are constructed.

    Note:
        NAC-UE additionally receives the module-level ``device``.
    """
    encoder = model.features
    classifier = model.fc

    def stage_modules():
        # A fresh list (with its own bn1+relu wrapper) on every call, so the
        # two multi-layer detectors never share container objects.
        return [
            model.conv1,
            model.block1,
            model.block2,
            model.block3,
            nn.Sequential(model.bn1, model.relu),
        ]

    def fc_trainable_copy():
        # Independent copy of the model in which only ``fc`` requires grad.
        clone = deepcopy(model)
        clone.requires_grad_(False)
        clone.fc.requires_grad_(True)
        return clone

    def is_fc_param(name):
        # Parameter filter used by the gradient-norm detectors.
        return name.startswith("fc")

    # The list is evaluated top to bottom, so detectors are constructed in
    # exactly the order in which they appear in the returned mapping.
    return OrderedDict(
        [
            # --- detectors that wrap the full model ---
            ("MSP", MaxSoftmax(model)),
            ("TemperatureScaling", TemperatureScaling(model)),
            ("Entropy", Entropy(model)),
            ("EnergyBased", EnergyBased(model)),
            ("MaxLogit", MaxLogit(model)),
            ("GEN", GEN(model)),
            ("KLMatching", KLMatching(model)),
            ("ODIN", ODIN(model, norm_std=norm_std, eps=0.002)),
            # MCD is disabled in this example:
            # ("MCD", MCD(model, samples=30, mode="var")),
            #
            # --- detectors built on ``model.features`` (some also use ``fc``) ---
            ("KNN", KNN(encoder)),
            ("GMM", GMM(encoder)),
            ("PNML", PNML(encoder, classifier)),
            ("NNGuide", NNGuide(encoder, classifier)),
            ("fDBD", fDBD(encoder=encoder, head=classifier)),
            ("Mahalanobis", Mahalanobis(encoder)),
            (
                "Mahalanobis+ODIN",
                MahalanobisODIN(encoder, norm_std=norm_std, eps=0.002),
            ),
            ("RMD", RMD(encoder)),
            ("ViM", ViM(encoder, d=64, w=classifier.weight, b=classifier.bias)),
            ("NCI", NCI(encoder=encoder, head=classifier, alpha=0.0)),
            ("SHE", SHE(encoder, classifier)),
            (
                "DICE",
                DICE(encoder=encoder, w=classifier.weight, b=classifier.bias, p=65.0),
            ),
            ("ReAct", ReAct(encoder, classifier, threshold=react_threshold)),
            ("VRA", VRA(encoder, classifier)),
            # --- detectors built on ``feature_maps`` / ``forward_feature_maps`` ---
            (
                "ASH",
                ASH(backbone=model.feature_maps, head=model.forward_feature_maps),
            ),
            (
                "RankFeat",
                RankFeat(backbone=model.feature_maps, head=model.forward_feature_maps),
            ),
            # --- detectors that receive a list of network stages ---
            ("MultiMahalanobis", MultiMahalanobis(stage_modules())),
            (
                "Gram",
                Gram(
                    num_classes=10,
                    head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
                    feature_layers=stage_modules(),
                ),
            ),
            # --- gradient-norm detectors, each on its own model copy ---
            ("GradNorm", GradNorm(fc_trainable_copy(), param_filter=is_fc_param)),
            ("GradNormKL", GradNormKL(fc_trainable_copy(), param_filter=is_fc_param)),
            # --- NAC-UE over three layers, one setting per layer ---
            (
                "NAC-UE",
                NACUE(
                    model=model,
                    layers=[model.block2, model.block3, model.bn1],
                    m_bins=[200, 200, 200],
                    alpha=[150.0, 200.0, 250.0],
                    o_star=[25, 50, 100],
                    device=device,
                ),
            ),
        ]
    )
147
148
def fit_detectors(detectors, train_loader, calibration_loader):
    """Fit, in place, each detector whose ``requires_fit`` attribute is truthy.

    Args:
        detectors: Mapping from detector name to detector instance.
        train_loader: Loader passed to ``fit`` for every detector except the
            two named below.
        calibration_loader: Loader passed to ``fit`` for the detectors named
            "TemperatureScaling" and "KLMatching".

    Detectors without a ``requires_fit`` attribute, or with a falsy one, are
    skipped. Every fitted detector is first moved to the module-level
    ``device``.
    """
    fitted_on_calibration = ("TemperatureScaling", "KLMatching")

    for name, det in detectors.items():
        if getattr(det, "requires_fit", False):
            if name in fitted_on_calibration:
                loader = calibration_loader
            else:
                loader = train_loader

            print(f"--> Fitting {name}")
            det.to(device)
            det.fit(loader)
print("STAGE 1: Creating model and benchmark")
# Pretrained WideResNet, switched to eval mode and moved to the chosen device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)

# Input transform and normalisation std that belong to the same weights.
trans = WideResNet.transform_for("cifar10-pt")
norm_std = WideResNet.norm_std_for("cifar10-pt")
benchmark = CIFAR10_OpenOOD(root="data", transform=trans)

train_dataset = benchmark.train_set()
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
# The calibration loader iterates, unshuffled, over the last 5000 indices of
# the training set.
calibration_loader = DataLoader(
    dataset=Subset(
        dataset=train_dataset,
        indices=range(len(train_dataset) - 5000, len(train_dataset)),
    ),
    shuffle=False,
    **loader_kwargs,
)

print("STAGE 2: Creating and fitting detectors")
detectors = build_detectors(model, norm_std, react_threshold)
fit_detectors(detectors, train_loader, calibration_loader)
print("STAGE 3: Evaluating detectors")
results = []

for detector_name, detector in detectors.items():
    print(f"> Evaluating {detector_name}")
    res = benchmark.evaluate(
        detector,
        loader_kwargs=loader_kwargs,
        device=device,
        cache=True,
        cache_dir=cache_dir,
        cache_key=cache_key,
    )
    # Tag each result row with the detector that produced it.
    for row in res:
        row["Detector"] = detector_name
    results.extend(res)

df = pd.DataFrame(results)

# Per-dataset table: every remaining column is scaled by 100 and printed as
# CSV with two decimals.
per_dataset = df.set_index(["Dataset", "Detector"]) * 100
print(per_dataset.to_csv(float_format="%.2f"))

print("\nMean scores:")
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
mean_scores = df.groupby("Detector")[metric_columns].mean()
# Rank detectors by mean AUROC, best first, then scale by 100 for printing.
ranked = mean_scores.sort_values("AUROC", ascending=False) * 100
print(ranked.to_csv(float_format="%.2f"))