OpenOOD - CIFAR10
Reproduces the OpenOOD benchmark for OOD detection, using the WideResNet model from the Hendrycks baseline paper.
Warning
This example is currently incomplete; see CIFAR10-OpenOOD.
import pandas as pd  # additional dependency, used here for convenience
import torch

from pytorch_ood.benchmark import CIFAR10_OpenOOD
from pytorch_ood.detector import MaxSoftmax, ReAct, ASH
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import fix_random_seed

fix_random_seed(123)

device = "cuda:0"
loader_kwargs = {"batch_size": 64}

model = WideResNet(num_classes=10, pretrained="cifar10-pt").eval().to(device)
trans = WideResNet.transform_for("cifar10-pt")
norm_std = WideResNet.norm_std_for("cifar10-pt")
To evaluate additional detectors, add them to this dictionary:
detectors = {
    "MSP": MaxSoftmax(model),
}
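The `MaxSoftmax` detector used here scores each input by its maximum softmax probability. The following is a minimal pure-Python sketch of that idea, not the library's implementation. The helper names `softmax` and `msp_score` are hypothetical, and the sign convention (negated, so that a higher score means "more likely OOD") is an assumption about how the scores are oriented.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def msp_score(logits):
    """Hypothetical helper: negated maximum softmax probability.

    Assumption: higher score = more likely out-of-distribution.
    """
    return -max(softmax(logits))


# A peaked prediction yields a low (very negative) outlier score,
# a flat prediction yields a higher one.
confident = msp_score([8.0, 0.5, 0.1])
uncertain = msp_score([1.0, 1.0, 1.0])
print(confident < uncertain)
```

With three equal logits the maximum probability is 1/3, so the flat prediction receives the highest score this three-class example can produce.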
results = []
benchmark = CIFAR10_OpenOOD(root="data", transform=trans)

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        res = benchmark.evaluate(detector, loader_kwargs=loader_kwargs, device=device)
        for r in res:
            r.update({"Detector": detector_name})
        results += res

df = pd.DataFrame(results)
print((df.set_index(["Dataset", "Detector"]) * 100).to_csv(float_format="%.2f"))
This should produce the following table:
| Dataset | Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| CIFAR100 | MSP | 87.83 | 85.20 | 88.42 | 43.08 |
| TinyImageNet | MSP | 87.06 | 85.05 | 86.82 | 51.27 |
| MNIST | MSP | 92.66 | 90.29 | 94.33 | 22.47 |
| FashionMNIST | MSP | 94.95 | 93.36 | 96.18 | 15.59 |
| Textures | MSP | 88.51 | 78.50 | 92.99 | 40.86 |
| Places365 | MSP | 88.24 | 95.61 | 71.17 | 44.65 |
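For readers unfamiliar with the AUROC and FPR95TPR columns, the sketch below shows one way to compute them from raw outlier scores. It is an illustration under stated assumptions, not the code the benchmark uses: OOD samples are treated as the positive class, higher scores mean "more likely OOD", the function names `auroc` and `fpr_at_95_tpr` are hypothetical, and the scores are made-up toy values rather than benchmark outputs. Library implementations may pick the 95% threshold slightly differently.

```python
import math


def auroc(ood_scores, id_scores):
    """Probability that a random OOD sample scores higher than a random
    in-distribution sample (ties count as one half)."""
    wins = 0.0
    for o in ood_scores:
        for i in id_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(ood_scores) * len(id_scores))


def fpr_at_95_tpr(ood_scores, id_scores):
    """Fraction of in-distribution samples flagged as OOD at the highest
    threshold that still detects at least 95% of the OOD samples."""
    ranked = sorted(ood_scores, reverse=True)
    k = math.ceil(0.95 * len(ranked))  # number of OOD samples to detect
    threshold = ranked[k - 1]
    false_positives = sum(1 for s in id_scores if s >= threshold)
    return false_positives / len(id_scores)


# Toy scores for illustration only.
ood = [0.9, 0.8, 0.7, 0.6]
ind = [0.1, 0.2, 0.65, 0.3]

auroc_value = auroc(ood, ind)
fpr_value = fpr_at_95_tpr(ood, ind)
print(f"AUROC={auroc_value:.4f} FPR95TPR={fpr_value:.2f}")
```

Higher AUROC is better, while lower FPR95TPR is better. The table reports both as percentages, which corresponds to multiplying these fractions by 100.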