OpenOOD - CIFAR10
Reproduces the OpenOOD benchmark for OOD detection, using the WideResNet model from the Hendrycks baseline paper.
Warning
This is currently incomplete; see CIFAR10-OpenOOD.
import pandas as pd  # additional dependency, used here for convenience
import torch

from pytorch_ood.benchmark import CIFAR10_OpenOOD
from pytorch_ood.detector import MaxSoftmax, ReAct, ASH  # ReAct and ASH are not used below
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import fix_random_seed

fix_random_seed(123)

device = "cuda:0"  # change to "cpu" if no GPU is available
loader_kwargs = {"batch_size": 64}

# Pre-trained WideResNet in evaluation mode, with its matching input transform
model = WideResNet(num_classes=10, pretrained="cifar10-pt").eval().to(device)
trans = WideResNet.transform_for("cifar10-pt")
norm_std = WideResNet.norm_std_for("cifar10-pt")  # not used further in this example
To evaluate additional detectors, add them to this dictionary:
detectors = {
    "MSP": MaxSoftmax(model),
}
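To make the "MSP" entry more concrete, here is a minimal sketch of the idea behind a maximum-softmax-probability score, written in plain Python rather than with the library. It assumes the convention that a higher score means "more likely OOD", so the score is the negated maximum softmax probability; the helper names `softmax` and `max_softmax_score` are hypothetical and are not part of pytorch_ood.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def max_softmax_score(logits):
    """Outlier score: negated maximum softmax probability (assumed convention)."""
    return -max(softmax(logits))


# A confident prediction gets a low outlier score,
# a uniform (maximally uncertain) prediction gets a higher one.
confident = max_softmax_score([8.0, 0.5, 0.1])
uncertain = max_softmax_score([1.0, 1.0, 1.0])
print(confident < uncertain)
```

The detector objects in the dictionary wrap this kind of computation around a model, which is why the benchmark can treat every detector uniformly.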
results = []
benchmark = CIFAR10_OpenOOD(root="data", transform=trans)

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        res = benchmark.evaluate(detector, loader_kwargs=loader_kwargs, device=device)
        # Tag each per-dataset result with the name of the detector that produced it
        for r in res:
            r.update({"Detector": detector_name})
        results += res

# Report all metrics in percent, indexed by dataset and detector
df = pd.DataFrame(results)
print((df.set_index(["Dataset", "Detector"]) * 100).to_csv(float_format="%.2f"))
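The tagging and formatting steps above can be sketched without pandas. The snippet below assumes that each evaluation result is a dictionary with a `Dataset` key and fractional metric values; it uses only two of the metric columns, with values taken from the CIFAR100 and MNIST rows of the table below divided by 100. The helper `tag_and_format` is hypothetical and only mirrors what the pandas one-liner does.

```python
import csv
import io

# Assumed shape of the per-dataset results (fractions, not percent)
res = [
    {"Dataset": "CIFAR100", "AUROC": 0.8782, "FPR95TPR": 0.4309},
    {"Dataset": "MNIST", "AUROC": 0.9266, "FPR95TPR": 0.2246},
]


def tag_and_format(res, detector_name):
    """Add a Detector column, scale metrics to percent, and render as CSV."""
    rows = [{**r, "Detector": detector_name} for r in res]
    key_cols = ["Dataset", "Detector"]
    metric_cols = [c for c in rows[0] if c not in key_cols]

    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow(key_cols + metric_cols)
    for row in rows:
        keys = [row[c] for c in key_cols]
        metrics = [f"{row[c] * 100:.2f}" for c in metric_cols]
        writer.writerow(keys + metrics)
    return buf.getvalue()


output = tag_and_format(res, "MSP")
print(output)
```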
This should produce the following table (all metrics in percent):
| Dataset | Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|---|
| CIFAR100 | MSP | 87.82 | 40.69 | 88.42 | 85.20 | 43.09 |
| TinyImageNet | MSP | 86.99 | 40.65 | 86.48 | 85.07 | 51.52 |
| MNIST | MSP | 92.66 | 37.23 | 94.33 | 90.30 | 22.46 |
| FashionMNIST | MSP | 94.95 | 33.53 | 96.18 | 93.36 | 15.58 |
| Textures | MSP | 88.51 | 39.68 | 92.99 | 78.50 | 40.89 |
| Places365 | MSP | 88.24 | 39.93 | 71.17 | 95.61 | 44.63 |
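For readers unfamiliar with the column names, the following is a rough sketch of how two of these metrics can be computed from raw outlier scores. It treats OOD samples as the positive class, which is an assumption here, and the library's own implementation may differ in details such as tie handling or interpolation. The functions `auroc` and `fpr_at_tpr` and the toy scores are hypothetical.

```python
import math


def auroc(ood_scores, id_scores):
    """Probability that a random OOD sample scores higher than a random
    in-distribution sample; ties count as half a win."""
    wins = 0.0
    for o in ood_scores:
        for i in id_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(ood_scores) * len(id_scores))


def fpr_at_tpr(ood_scores, id_scores, tpr=0.95):
    """Fraction of in-distribution samples flagged as OOD at the threshold
    that flags at least `tpr` of the OOD samples."""
    ranked = sorted(ood_scores, reverse=True)
    k = math.ceil(tpr * len(ranked))  # OOD samples that must be flagged
    threshold = ranked[k - 1]
    flagged_id = sum(1 for s in id_scores if s >= threshold)
    return flagged_id / len(id_scores)


# Toy outlier scores, for illustration only (higher = more OOD-like)
ood = [0.9, 0.8, 0.7, 0.6]
ind = [0.1, 0.2, 0.65, 0.3]
print(auroc(ood, ind), fpr_at_tpr(ood, ind))
```

Higher is better for AUROC, while lower is better for FPR95TPR, which is worth keeping in mind when comparing detectors row by row.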