OpenOOD - CIFAR10

Reproduces the OpenOOD benchmark for OOD detection, using the WideResNet model from the Hendrycks baseline paper.

Warning

This is currently incomplete, see CIFAR10-OpenOOD.

import pandas as pd  # additional dependency, used here for convenience
import torch

from pytorch_ood.benchmark import CIFAR10_OpenOOD
from pytorch_ood.detector import MaxSoftmax, ReAct, ASH  # ReAct and ASH are not used in this snippet
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import fix_random_seed

fix_random_seed(123)  # seed the random number generators for reproducibility

# use the first GPU if available, otherwise fall back to the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
loader_kwargs = {"batch_size": 64}

# pretrained WideResNet and the input preprocessing that matches its weights
model = WideResNet(num_classes=10, pretrained="cifar10-pt").eval().to(device)
trans = WideResNet.transform_for("cifar10-pt")
norm_std = WideResNet.norm_std_for("cifar10-pt")  # not needed by MSP

To evaluate further detectors, add them to this dictionary.

detectors = {
    "MSP": MaxSoftmax(model),
}
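MaxSoftmax implements the maximum softmax probability (MSP) baseline of Hendrycks and Gimpel. The dependency-free sketch below illustrates the idea on a single vector of logits; it is not the library's implementation. It assumes pytorch-ood's convention that a larger outlier score means "more likely OOD", which is why the maximum probability is negated.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax over one vector of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def msp_outlier_score(logits: List[float]) -> float:
    """Negated maximum softmax probability: higher means more likely OOD (assumed convention)."""
    return -max(softmax(logits))


# A confident prediction (one dominant logit) vs. a maximally uncertain one (flat logits).
confident = [8.0, 0.5, 0.1, 0.0, 0.2, 0.3, 0.1, 0.0, 0.4, 0.2]
uncertain = [1.0] * 10

print(round(msp_outlier_score(confident), 3))
print(msp_outlier_score(uncertain))  # -0.1, since all ten classes get probability 0.1
```

A flat softmax yields the highest outlier score possible for ten classes, so inputs on which the classifier is unsure are ranked as more OOD.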
results = []
benchmark = CIFAR10_OpenOOD(root="data", transform=trans)

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        # one dictionary of metrics per OOD dataset
        res = benchmark.evaluate(detector, loader_kwargs=loader_kwargs, device=device)
        for r in res:
            r.update({"Detector": detector_name})
        results += res

# scale the metrics to percent and print them as CSV
df = pd.DataFrame(results)
print((df.set_index(["Dataset", "Detector"]) * 100).to_csv(float_format="%.2f"))
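The final two lines turn the collected dictionaries into a CSV with the dataset and detector as leading columns and every metric scaled to percent. The standard-library sketch below roughly mirrors that step, to show the shape of the printed output. It assumes each result row holds metric values as fractions plus "Dataset" and "Detector" keys; the two rows are taken from the MSP results reported for this example, restricted to AUROC and FPR95TPR and converted back to fractions. The helper `to_percent_csv` is hypothetical.

```python
import csv
import io
from typing import Dict, List, Tuple, Union

Row = Dict[str, Union[str, float]]

# Assumed row shape: metric values as fractions, tagged with dataset and detector.
results: List[Row] = [
    {"Dataset": "CIFAR100", "AUROC": 0.8782, "FPR95TPR": 0.4309, "Detector": "MSP"},
    {"Dataset": "MNIST", "AUROC": 0.9266, "FPR95TPR": 0.2246, "Detector": "MSP"},
]


def to_percent_csv(rows: List[Row], index: Tuple[str, ...] = ("Dataset", "Detector")) -> str:
    """Hypothetical helper: index columns first, metrics as percentages with two decimals."""
    metrics = [k for k in rows[0] if k not in index]
    buf = io.StringIO()
    writer = csv.writer(buf, lineterminator="\n")
    writer.writerow([*index, *metrics])
    for row in rows:
        writer.writerow([*(row[k] for k in index), *(f"{row[m] * 100:.2f}" for m in metrics)])
    return buf.getvalue()


print(to_percent_csv(results))
```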

This should produce the following table (all values in percent):

Dataset       Detector  AUROC  AUTC   AUPR-IN  AUPR-OUT  FPR95TPR
CIFAR100      MSP       87.82  40.69  88.42    85.20     43.09
TinyImageNet  MSP       86.99  40.65  86.48    85.07     51.52
MNIST         MSP       92.66  37.23  94.33    90.30     22.46
FashionMNIST  MSP       94.95  33.53  96.18    93.36     15.58
Textures      MSP       88.51  39.68  92.99    78.50     40.89
Places365     MSP       88.24  39.93  71.17    95.61     44.63
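To make two of these columns concrete, here is a minimal, dependency-free sketch that computes AUROC and the false positive rate at 95% true positive rate from outlier scores. It assumes OOD samples are the positive class and that higher scores mean "more OOD"; the pytorch-ood metric implementation may differ in tie handling, interpolation, and the choice of positive class, so this is an illustration rather than a way to reproduce the exact numbers. (AUPR-IN and AUPR-OUT are precision-recall areas with in-distribution (IN) or OOD samples, respectively, treated as positive.)

```python
from typing import Sequence


def auroc(in_scores: Sequence[float], ood_scores: Sequence[float]) -> float:
    """Probability that a random OOD sample scores higher than a random IN sample.

    Ties count as one half (the Mann-Whitney U formulation of the ROC area).
    """
    wins = 0.0
    for o in ood_scores:
        for i in in_scores:
            if o > i:
                wins += 1.0
            elif o == i:
                wins += 0.5
    return wins / (len(in_scores) * len(ood_scores))


def fpr_at_tpr(in_scores: Sequence[float], ood_scores: Sequence[float], tpr: float = 0.95) -> float:
    """Fraction of IN samples flagged at the highest threshold that flags >= `tpr` of OOD samples."""
    for t in sorted(set(ood_scores), reverse=True):
        detected = sum(s >= t for s in ood_scores) / len(ood_scores)
        if detected >= tpr:
            return sum(s >= t for s in in_scores) / len(in_scores)
    return 1.0  # only reached if ood_scores is empty


# Toy outlier scores (higher = more likely OOD); illustrative values only.
in_scores = [0.1, 0.2, 0.3, 0.4, 0.9]
ood_scores = [0.5, 0.6, 0.7, 0.8, 0.95]

print(auroc(in_scores, ood_scores))       # 0.84 (21 of 25 pairs ranked correctly)
print(fpr_at_tpr(in_scores, ood_scores))  # 0.2 (one of five IN samples reaches the threshold 0.5)
```

With only five OOD samples, reaching 95% TPR means flagging all of them, so the threshold drops to the lowest OOD score; on benchmark-sized datasets the threshold is much finer grained.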
