OpenOOD - CIFAR10

Reproduces the OpenOOD benchmark for out-of-distribution (OOD) detection on CIFAR10, using the WideResNet model from the Hendrycks baseline paper.

Warning

This example is currently incomplete; see CIFAR10-OpenOOD.

import pandas as pd  # additional dependency, used here for convenience
import torch

from pytorch_ood.benchmark import CIFAR10_OpenOOD
from pytorch_ood.detector import MaxSoftmax, ReAct, ASH
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import fix_random_seed

fix_random_seed(123)

device = "cuda:0"  # requires a CUDA-capable GPU
loader_kwargs = {"batch_size": 64}

# Pretrained CIFAR10 model, with the input transform and normalization std of the same checkpoint
model = WideResNet(num_classes=10, pretrained="cifar10-pt").eval().to(device)
trans = WideResNet.transform_for("cifar10-pt")
norm_std = WideResNet.norm_std_for("cifar10-pt")  # unused by MSP below

To evaluate additional detectors, add them to this dictionary:

detectors = {
    "MSP": MaxSoftmax(model),
}
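For intuition, the maximum softmax probability (MSP) baseline scores an input by how peaked the classifier's softmax output is. The following is a minimal standard-library sketch of that idea, not the library's implementation. The helper names `softmax` and `msp_score` are hypothetical, and the sign convention (negated probability, so that higher scores mean "more likely OOD") is an assumption about how outlier scores are oriented.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax over a list of logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def msp_score(logits, temperature=1.0):
    """Negated maximum softmax probability (assumed convention: higher = more OOD)."""
    return -max(softmax(logits, temperature))


confident = msp_score([8.0, 0.5, 0.1])  # peaked prediction, typical for ID inputs
uncertain = msp_score([1.0, 1.0, 1.0])  # flat prediction over three classes

print(confident < uncertain)  # the flat prediction receives the higher outlier score
```

A flat prediction over three classes has a maximum probability of 1/3, so its score is -1/3, while a peaked prediction approaches -1.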
results = []
benchmark = CIFAR10_OpenOOD(root="data", transform=trans)

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        res = benchmark.evaluate(detector, loader_kwargs=loader_kwargs, device=device)
        # Tag each per-dataset result with the detector that produced it
        for r in res:
            r.update({"Detector": detector_name})
        results += res

# Report all metrics in percent, with two decimal places
df = pd.DataFrame(results)
print((df.set_index(["Dataset", "Detector"]) * 100).to_csv(float_format="%.2f"))

This should produce the following table:

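============  ========  =====  =======  ========  ========
Dataset       Detector  AUROC  AUPR-IN  AUPR-OUT  FPR95TPR
============  ========  =====  =======  ========  ========
CIFAR100      MSP       87.83  85.20    88.42     43.08
TinyImageNet  MSP       87.06  85.05    86.82     51.27
MNIST         MSP       92.66  90.29    94.33     22.47
FashionMNIST  MSP       94.95  93.36    96.18     15.59
Textures      MSP       88.51  78.50    92.99     40.86
Places365     MSP       88.24  95.61    71.17     44.65
============  ========  =====  =======  ========  ========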
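To make two of the reported metrics concrete, here is a rough standard-library sketch of how AUROC and FPR95TPR can be computed from outlier scores. It is illustrative only and not the library's metric code. The function names and toy scores are hypothetical, and it assumes OOD is the positive class with higher scores meaning "more OOD". The benchmark reports these values multiplied by 100.

```python
import math


def auroc(id_scores, ood_scores):
    """Probability that a random OOD sample outscores a random ID sample (ties count 1/2)."""
    wins = 0.0
    for o in ood_scores:
        for i in id_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(id_scores) * len(ood_scores))


def fpr_at_95_tpr(id_scores, ood_scores):
    """Fraction of ID samples flagged at the threshold that detects at least 95% of OOD samples."""
    ranked = sorted(ood_scores, reverse=True)
    k = math.ceil(0.95 * len(ranked))  # number of OOD samples that must be detected
    threshold = ranked[k - 1]
    flagged = sum(1 for s in id_scores if s >= threshold)
    return flagged / len(id_scores)


# Toy outlier scores (hypothetical values, for illustration only)
id_scores = [0.1, 0.2, 0.3, 0.4, 0.8]
ood_scores = [0.35, 0.6, 0.7, 0.9]

print(f"AUROC:    {100 * auroc(id_scores, ood_scores):.2f}")
print(f"FPR95TPR: {100 * fpr_at_95_tpr(id_scores, ood_scores):.2f}")
```

Higher AUROC is better, whereas lower FPR95TPR is better, which is why MNIST and FashionMNIST combine the highest AUROC with the lowest FPR95TPR in the table above.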
