OpenOOD v1.5 - CIFAR10, Many Detectors

Evaluates a broad set of image-classification detectors on the CIFAR-10 OpenOOD benchmark using the benchmark interface and cached intermediate representations.

This example focuses on detectors that can run directly on the pretrained WideResNet used throughout the repository. It omits methods that require extra external dependencies or method-specific trained weights, such as OpenMax and WeightedEBO. MCD appears in the listing only as commented-out lines and is therefore not evaluated either.

 14 from collections import OrderedDict
 15 from copy import deepcopy
 16
 17 import pandas as pd  # additional dependency, used here for convenience
 18 import torch
 19 from torch import nn
 20 from torch.utils.data import DataLoader, Subset
 21
 22 from pytorch_ood.benchmark import CIFAR10_OpenOOD
 23 from pytorch_ood.detector import (
 24     ASH,
 25     DICE,
 26     EnergyBased,
 27     Entropy,
 28     GEN,
 29     GMM,
 30     GradNorm,
 31     GradNormKL,
 32     Gram,
 33     KLMatching,
 34     KNN,
 35     Mahalanobis,
 36     MahalanobisODIN,
 37     MaxLogit,
 38     MaxSoftmax,
 39     # MCD,
 40     MultiMahalanobis,
 41     NACUE,
 42     NCI,
 43     NNGuide,
 44     ODIN,
 45     PNML,
 46     RMD,
 47     RankFeat,
 48     ReAct,
 49     SHE,
 50     TemperatureScaling,
 51     ViM,
 52     VRA,
 53     fDBD,
 54 )
 55 from pytorch_ood.model import WideResNet
 56 from pytorch_ood.utils import fix_random_seed
 57
 58 fix_random_seed(123)
 59
 60 device = "cuda:0" if torch.cuda.is_available() else "cpu"
 61 loader_kwargs = {"batch_size": 128, "num_workers": 12}
 62 cache_dir = "data/benchmark-cache"
 63 cache_key = "cifar10-openood-wrn-cifar10-pt"
 64 react_threshold = 1.0
 65
 66
 67 def build_detectors(model, norm_std, react_threshold):
 68     detectors = OrderedDict()
 69
 70     detectors["MSP"] = MaxSoftmax(model)
 71     detectors["TemperatureScaling"] = TemperatureScaling(model)
 72     detectors["Entropy"] = Entropy(model)
 73     detectors["EnergyBased"] = EnergyBased(model)
 74     detectors["MaxLogit"] = MaxLogit(model)
 75     detectors["GEN"] = GEN(model)
 76     detectors["KLMatching"] = KLMatching(model)
 77     detectors["ODIN"] = ODIN(model, norm_std=norm_std, eps=0.002)
 78     # detectors["MCD"] = MCD(model, samples=30, mode="var")
 79
 80     detectors["KNN"] = KNN(model.features)
 81     detectors["GMM"] = GMM(model.features)
 82     detectors["PNML"] = PNML(model.features, model.fc)
 83     detectors["NNGuide"] = NNGuide(model.features, model.fc)
 84     detectors["fDBD"] = fDBD(encoder=model.features, head=model.fc)
 85     detectors["Mahalanobis"] = Mahalanobis(model.features)
 86     detectors["Mahalanobis+ODIN"] = MahalanobisODIN(model.features, norm_std=norm_std, eps=0.002)
 87     detectors["RMD"] = RMD(model.features)
 88     detectors["ViM"] = ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias)
 89     detectors["NCI"] = NCI(encoder=model.features, head=model.fc, alpha=0.0)
 90     detectors["SHE"] = SHE(model.features, model.fc)
 91     detectors["DICE"] = DICE(encoder=model.features, w=model.fc.weight, b=model.fc.bias, p=65.0)
 92     detectors["ReAct"] = ReAct(model.features, model.fc, threshold=react_threshold)
 93     detectors["VRA"] = VRA(model.features, model.fc)
 94
 95     detectors["ASH"] = ASH(
 96         backbone=model.feature_maps,
 97         head=model.forward_feature_maps,
 98     )
 99     detectors["RankFeat"] = RankFeat(
100         backbone=model.feature_maps,
101         head=model.forward_feature_maps,
102     )
103
104     detectors["MultiMahalanobis"] = MultiMahalanobis(
105         [
106             model.conv1,
107             model.block1,
108             model.block2,
109             model.block3,
110             nn.Sequential(model.bn1, model.relu),
111         ]
112     )
113     detectors["Gram"] = Gram(
114         num_classes=10,
115         head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
116         feature_layers=[
117             model.conv1,
118             model.block1,
119             model.block2,
120             model.block3,
121             nn.Sequential(model.bn1, model.relu),
122         ],
123     )
124
125     model_gn = deepcopy(model)
126     model_gn.requires_grad_(False)
127     model_gn.fc.requires_grad_(True)
128     detectors["GradNorm"] = GradNorm(model_gn, param_filter=lambda name: name.startswith("fc"))
129
130     model_gnkl = deepcopy(model)
131     model_gnkl.requires_grad_(False)
132     model_gnkl.fc.requires_grad_(True)
133     detectors["GradNormKL"] = GradNormKL(
134         model_gnkl, param_filter=lambda name: name.startswith("fc")
135     )
136
137     detectors["NAC-UE"] = NACUE(
138         model=model,
139         layers=[model.block2, model.block3, model.bn1],
140         m_bins=[200, 200, 200],
141         alpha=[150.0, 200.0, 250.0],
142         o_star=[25, 50, 100],
143         device=device,
144     )
145
146     return detectors
147
148
def fit_detectors(detectors, train_loader, calibration_loader):
    """Fit every detector that declares it needs fitting.

    Detectors whose ``requires_fit`` attribute is missing or falsy are
    skipped. "TemperatureScaling" and "KLMatching" are fitted on
    ``calibration_loader``; all other detectors are fitted on
    ``train_loader``. Each fitted detector is first moved to the
    module-level ``device``.

    Args:
        detectors: Mapping from detector name to detector instance.
        train_loader: Loader used for all detectors not listed above.
        calibration_loader: Loader used for the two calibration-based ones.
    """
    calibration_fitted = ("TemperatureScaling", "KLMatching")

    for name, det in detectors.items():
        if getattr(det, "requires_fit", False):
            if name in calibration_fitted:
                loader = calibration_loader
            else:
                loader = train_loader

            print(f"--> Fitting {name}")
            det.to(device)
            det.fit(loader)
print("STAGE 1: Creating model and benchmark")
# Pretrained WideResNet in eval mode, together with its matching input
# transform and normalization std.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)
trans = WideResNet.transform_for("cifar10-pt")
norm_std = WideResNet.norm_std_for("cifar10-pt")
benchmark = CIFAR10_OpenOOD(root="data", transform=trans)

train_dataset = benchmark.train_set()
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# The calibration split is the last 5000 samples of the training set; note
# that train_loader above still iterates over the full training set.
num_train = len(train_dataset)
calibration_indices = range(num_train - 5000, num_train)
calibration_loader = DataLoader(
    Subset(train_dataset, calibration_indices),
    shuffle=False,
    **loader_kwargs,
)

print("STAGE 2: Creating and fitting detectors")
detectors = build_detectors(
    model=model,
    norm_std=norm_std,
    react_threshold=react_threshold,
)
fit_detectors(
    detectors=detectors,
    train_loader=train_loader,
    calibration_loader=calibration_loader,
)

print("STAGE 3: Evaluating detectors")
results = []

for detector_name, detector in detectors.items():
    print(f"> Evaluating {detector_name}")
    # Every detector is evaluated with the same cache directory and key.
    rows = benchmark.evaluate(
        detector,
        loader_kwargs=loader_kwargs,
        device=device,
        cache=True,
        cache_dir=cache_dir,
        cache_key=cache_key,
    )
    # Tag each result row with the detector that produced it.
    for row in rows:
        row["Detector"] = detector_name
    results.extend(rows)

df = pd.DataFrame(results)

# Per-dataset table: values are multiplied by 100 and printed as CSV with
# two decimals.
per_dataset = df.set_index(["Dataset", "Detector"]) * 100
print(per_dataset.to_csv(float_format="%.2f"))

print("\nMean scores:")
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
mean_scores = df.groupby("Detector")[metric_columns].mean()
# Rank detectors by mean AUROC, best first.
ranked = mean_scores.sort_values("AUROC", ascending=False) * 100
print(ranked.to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery