CIFAR-10

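Example benchmark code for CIFAR-10. The table below reports, for each detector, the mean of each metric (in percent) over all OOD datasets, sorted by AUROC in ascending order.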

| Detector         | AUROC | AUTC  | AUPR-IN | AUPR-OUT | FPR95TPR |
|------------------|-------|-------|---------|----------|----------|
| Gram             | 69.37 | 46.01 | 58.02   | 77.49    | 75.03    |
| KLMatching       | 88.48 | 39.83 | 72.29   | 91.33    | 57.84    |
| SHE              | 90.08 | 39.69 | 69.17   | 92.92    | 38.48    |
| MSP              | 91.41 | 37.07 | 86.36   | 92.42    | 29.93    |
| Entropy          | 92.03 | 35.90 | 86.70   | 93.47    | 29.75    |
| Mahalanobis      | 92.14 | 42.76 | 86.39   | 94.36    | 28.25    |
| ODIN             | 92.14 | 47.06 | 84.98   | 94.46    | 34.43    |
| ViM              | 92.32 | 40.22 | 85.76   | 94.93    | 29.49    |
| Mahalanobis+ODIN | 92.60 | 42.76 | 86.81   | 95.08    | 27.11    |
| DICE             | 92.80 | 35.83 | 86.68   | 94.20    | 32.35    |
| MaxLogit         | 93.05 | 35.84 | 87.01   | 94.40    | 31.31    |
| EnergyBased      | 93.11 | 35.45 | 87.09   | 94.46    | 31.14    |
| MultiMahalanobis | 93.43 | 44.60 | 86.70   | 96.48    | 22.95    |
| RMD              | 93.46 | 32.09 | 87.73   | 95.08    | 26.99    |

44 import pandas as pd  # additional dependency, used here for convenience
45 import torch
46 from torch import nn
47 from torch.utils.data import DataLoader
48 from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
49
50 from pytorch_ood.dataset.img import (
51     LSUNCrop,
52     LSUNResize,
53     Textures,
54     TinyImageNetCrop,
55     TinyImageNetResize,
56     Places365,
57 )
58 from pytorch_ood.detector import (
59     ODIN,
60     EnergyBased,
61     Entropy,
62     KLMatching,
63     Mahalanobis,
64     MaxLogit,
65     MaxSoftmax,
66     ViM,
67     RMD,
68     DICE,
69     SHE,
70     Gram,
71     MultiMahalanobis,
72 )
73 from pytorch_ood.model import WideResNet
74 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
75
# Prefer the first GPU, but fall back to the CPU so the example still runs
# (slowly) on machines without CUDA instead of failing on the first `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fixed seed for reproducible runs (presumably seeds the Python/NumPy/torch RNGs
# via pytorch_ood.utils.fix_random_seed — verify against that helper).
fix_random_seed(123)

Set up preprocessing

# Normalization std and input transform matching the "cifar10-pt" WideResNet
# weights; the std is later handed to the input-perturbing detectors (ODIN,
# Mahalanobis+ODIN).
norm_std = WideResNet.norm_std_for("cifar10-pt")
trans = WideResNet.transform_for("cifar10-pt")

Set up datasets

 88 dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)
 89
 90 # create all OOD datasets
 91 ood_datasets = [
 92     Textures,
 93     TinyImageNetCrop,
 94     TinyImageNetResize,
 95     LSUNCrop,
 96     LSUNResize,
 97     Places365,
 98     CIFAR100,
 99     MNIST,
100     FashionMNIST,
101 ]
102 datasets = {}
103 for ood_dataset in ood_datasets:
104     dataset_out_test = ood_dataset(
105         root="data", transform=trans, target_transform=ToUnknown(), download=True
106     )
107     test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=512, num_workers=12)
108     datasets[ood_dataset.__name__] = test_loader

Stage 1: Create a DNN with pre-trained weights from the Hendrycks baseline paper

print("STAGE 1: Creating a Model")
# WideResNet with the pre-trained "cifar10-pt" weights (10 CIFAR-10 classes).
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
# Evaluation mode: layers such as BatchNorm (model.bn1) use inference behavior.
model.eval()
model = model.to(device)

Stage 2: Create OOD detectors

print("STAGE 2: Creating OOD Detectors")


def _feature_stages():
    """Return a fresh list of the WideResNet stages used as per-layer feature extractors.

    A new list (and a new ``nn.Sequential`` wrapper) is built on every call so
    that each detector receives its own container, as in separate literals.
    """
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


# Insertion order here determines the order in which detectors are fitted and evaluated.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    # Mahalanobis with input perturbation (eps) in the style of ODIN.
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
    "MultiMahalanobis": MultiMahalanobis(_feature_stages()),
    "Gram": Gram(
        num_classes=10,
        head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
        feature_layers=_feature_stages(),
    ),
}

# Fit every detector on the CIFAR10 training split (some require this, some do not).
print(f"> Fitting {len(detectors)} detectors")
train_split = CIFAR10(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_split, batch_size=512, num_workers=12)
for det_name, det in detectors.items():
    print(f"--> Fitting {det_name}")
    det.fit(loader_in_train, device=device)

Stage 3: Evaluate Detectors

print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []


def _score_loader(det, data_loader):
    """Run ``det`` over every batch of ``data_loader`` and return the computed OODMetrics."""
    accumulator = OODMetrics()
    for inputs, labels in data_loader:
        accumulator.update(det(inputs.to(device)), labels.to(device))
    return accumulator.compute()


with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            row = {"Detector": detector_name, "Dataset": dataset_name}
            row.update(_score_loader(detector, loader))
            results.append(row)

# Average each metric over all OOD datasets per detector, expressed in percent.
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean().mul(100)
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery