CIFAR 100

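The evaluation is the same as for CIFAR-10. The table below reports each detector's scores in percent, averaged over all OOD datasets and sorted by ascending AUROC.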

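| Detector         | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|------------------|-------|---------|----------|----------|
| SHE              | 59.43 | 77.44   | 68.37    | 100.00   |
| Mahalanobis      | 75.35 | 81.59   | 65.62    | 58.87    |
| MSP              | 78.78 | 82.37   | 71.34    | 57.67    |
| Mahalanobis+ODIN | 79.24 | 84.58   | 68.69    | 55.91    |
| KLMatching       | 79.88 | 83.53   | 68.23    | 60.02    |
| ODIN             | 80.80 | 83.96   | 73.40    | 54.92    |
| Entropy          | 81.19 | 84.61   | 73.08    | 56.49    |
| ViM              | 81.73 | 85.87   | 72.91    | 49.86    |
| RMD              | 83.23 | 86.94   | 74.56    | 50.55    |
| MaxLogit         | 84.70 | 86.66   | 78.33    | 47.40    |
| EnergyBased      | 85.00 | 86.88   | 78.69    | 46.70    |
| DICE             | 85.35 | 87.32   | 78.99    | 46.17    |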

import pandas as pd  # additional dependency, used here for convenience
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST

from pytorch_ood.dataset.img import (
    LSUNCrop,
    LSUNResize,
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    Places365,
    TinyImageNet,
)
from pytorch_ood.detector import (
    ODIN,
    EnergyBased,
    Entropy,
    KLMatching,
    Mahalanobis,
    MaxLogit,
    MaxSoftmax,
    ViM,
    RMD,
    DICE,
    SHE,
)
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed

# Use the first GPU when one is available. Otherwise fall back to the CPU, so the
# example still runs (slowly) instead of failing on the first `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed so repeated runs are comparable.
fix_random_seed(123)

# setup preprocessing
# "cifar100-pt" is the same key that is used to load the pre-trained WideResNet weights.
trans = WideResNet.transform_for("cifar100-pt")
# NOTE(review): presumably the per-channel normalization std. It is passed to the
# detectors that perturb inputs (ODIN, Mahalanobis with eps) — confirm.
norm_std = WideResNet.norm_std_for("cifar100-pt")

Set up datasets

dataset_in_test = CIFAR100(root="data", train=False, transform=trans, download=True)

# OOD test sets. Each one is concatenated with the CIFAR-100 test set. ToUnknown()
# is applied as the target transform, presumably marking every OOD label as
# "unknown" — confirm against pytorch_ood.utils.
ood_datasets = [
    Textures,
    TinyImageNetCrop,
    TinyImageNetResize,
    LSUNCrop,
    LSUNResize,
    Places365,
    CIFAR10,
    MNIST,
    FashionMNIST,
]
# NOTE(review): the torchvision datasets (CIFAR10, MNIST, FashionMNIST) are built
# without a `train=` argument, so they use torchvision's default split. Confirm
# that this is the intended split for OOD testing.


def _build_ood_loader(ood_cls):
    """Return a loader over the CIFAR-100 test set concatenated with one OOD dataset."""
    ood_part = ood_cls(
        root="data", transform=trans, target_transform=ToUnknown(), download=True
    )
    return DataLoader(dataset_in_test + ood_part, batch_size=256, num_workers=12)


# Maps OOD dataset class name -> loader over (ID test set + that OOD dataset).
datasets = {ood_cls.__name__: _build_ood_loader(ood_cls) for ood_cls in ood_datasets}

Stage 1: Create a DNN with pre-trained weights from the Hendrycks baseline paper

print("STAGE 1: Creating a Model")
model = WideResNet(num_classes=100, pretrained="cifar100-pt").eval().to(device)

# Stage 2: Create OOD detector
print("STAGE 2: Creating OOD Detectors")
# A dict literal keeps insertion order, so detectors are constructed, fitted and
# evaluated in exactly the order listed here. Some detectors wrap the full model.
# Others wrap `model.features` and, where needed, the weights of `model.fc`.
# "Mahalanobis+ODIN" is Mahalanobis with an input-perturbation eps and norm_std.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
}

Stage 2: fit detectors to training data (some require this, some do not)

print(f"> Fitting {len(detectors)} detectors")
# In-distribution training data (CIFAR-100 train split) used to fit the detectors.
train_set = CIFAR100(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=256, num_workers=12)

# fit() is called on every detector so the loop stays uniform. Per the example
# text, only some detectors actually require training data.
for detector_name, detector in detectors.items():
    print(f"--> Fitting {detector_name}")
    detector.fit(loader_in_train, device=device)

Stage 3: Evaluate Detectors

print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# One row per (detector, dataset) pair: the two names plus every metric from OODMetrics.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for x, y in loader:
                metrics.update(detector(x.to(device)), y.to(device))

            r = {"Detector": detector_name, "Dataset": dataset_name}
            r.update(metrics.compute())
            results.append(r)

# calculate mean scores over all datasets, use percent
# numeric_only=True skips the string "Dataset" column. Since pandas 2.0,
# GroupBy.mean() raises a TypeError on non-numeric columns instead of
# silently dropping them.
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery