CIFAR 100

The evaluation is the same as for CIFAR 10.

Detector

AUROC

AUTC

AUPR-IN

AUPR-OUT

FPR95TPR

Gram

48.29

50.66

38.24

63.76

91.97

SHE

59.43

43.67

68.37

77.44

100.00

Mahalanobis

75.35

45.59

65.62

81.59

58.87

MSP

78.78

37.32

71.34

82.37

57.67

Mahalanobis+ODIN

79.24

44.89

68.69

84.58

55.91

KLMatching

79.88

41.07

68.23

83.53

60.02

ODIN

80.80

44.90

73.40

83.96

54.92

Entropy

81.19

38.44

73.08

84.61

56.49

ViM

81.73

43.50

72.91

85.87

49.86

RMD

83.23

39.43

74.56

86.94

50.55

MaxLogit

84.70

41.89

78.33

86.66

47.40

EnergyBased

85.00

41.89

78.69

86.88

46.70

MultiMahalanobis

85.33

45.93

77.84

89.51

39.25

DICE

85.35

41.84

78.99

87.32

46.17

43 import pandas as pd  # additional dependency, used here for convenience
44 import torch
45 from torch.utils.data import DataLoader
46 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST
47 from torch import nn
48
49 from pytorch_ood.dataset.img import (
50     LSUNCrop,
51     LSUNResize,
52     Textures,
53     TinyImageNetCrop,
54     TinyImageNetResize,
55     Places365,
56 )
57 from pytorch_ood.detector import (
58     ODIN,
59     EnergyBased,
60     Entropy,
61     GEN,
62     KLMatching,
63     Mahalanobis,
64     MaxLogit,
65     MaxSoftmax,
66     ViM,
67     RMD,
68     DICE,
69     SHE,
70     Gram,
71     GMM,
72     MultiMahalanobis,
73     RankFeat,
74     fDBD,
75 )
76 from pytorch_ood.model import WideResNet
77 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
78
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the script can still run on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fixed seed (123); presumably makes repeated runs comparable -- see
# pytorch_ood.utils.fix_random_seed for exactly what is seeded.
fix_random_seed(123)

# setup preprocessing
# Both values are looked up under the same "cifar100-pt" key that is used for
# the pretrained weights of the model below.
trans = WideResNet.transform_for("cifar100-pt")
# norm_std is handed to the ODIN and "Mahalanobis+ODIN" detectors further down.
norm_std = WideResNet.norm_std_for("cifar100-pt")

Setup datasets

 89 dataset_in_test = CIFAR100(root="data", train=False, transform=trans, download=True)
 90
 91 # create all OOD datasets
 92 ood_datasets = [
 93     Textures,
 94     TinyImageNetCrop,
 95     TinyImageNetResize,
 96     LSUNCrop,
 97     LSUNResize,
 98     Places365,
 99     CIFAR10,
100     MNIST,
101     FashionMNIST,
102 ]
103 datasets = {}
104 for ood_dataset in ood_datasets:
105     dataset_out_test = ood_dataset(
106         root="data", transform=trans, target_transform=ToUnknown(), download=True
107     )
108     test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=256, num_workers=12)
109     datasets[ood_dataset.__name__] = test_loader

Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper

print("STAGE 1: Creating a Model")
model = WideResNet(num_classes=100, pretrained="cifar100-pt")
model = model.eval().to(device)


def _feature_stages():
    """Return a fresh list of the model stages used as intermediate feature layers.

    A new list (and a new ``nn.Sequential`` wrapper around ``bn1``/``relu``) is
    built on every call, so detectors never share the container objects.
    """
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


# Create the OOD detectors
print("STAGE 2: Creating OOD Detectors")
# Insertion order is kept: it determines the order of fitting and evaluation.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    # Same detector class as "Mahalanobis", but with norm_std and eps set.
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "GEN": GEN(model),
    "GMM": GMM(model.features),
    "fDBD": fDBD(encoder=model.features, head=model.fc),
    "RankFeat": RankFeat(
        backbone=model.features_before_pool,
        head=model.forward_from_before_pool,
    ),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
    "MultiMahalanobis": MultiMahalanobis(_feature_stages()),
    "Gram": Gram(
        num_classes=100,
        head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
        feature_layers=_feature_stages(),
    ),
}

Stage 2 (continued): Fit the detectors to the training data (some detectors require fitting, others do not)

print(f"> Fitting {len(detectors)} detectors")

# CIFAR-100 training split with the same preprocessing as the test data.
dataset_in_train = CIFAR100(root="data", train=True, transform=trans)
loader_in_train = DataLoader(dataset_in_train, batch_size=256, num_workers=12)

# Every detector is moved to the target device and fitted on the same loader,
# in the order in which the detectors were registered.
for label in detectors:
    print(f"--> Fitting {label}")
    current = detectors[label]
    current.to(device)
    current.fit(loader_in_train)

Stage 3: Evaluate Detectors

def _score(detector, loader):
    """Run ``detector`` over every batch of ``loader`` and return the computed metrics.

    Inputs and labels are moved to ``device`` before they are handed to the
    detector and to the ``OODMetrics`` accumulator.
    """
    metrics = OODMetrics()
    for inputs, labels in loader:
        metrics.update(detector(inputs.to(device)), labels.to(device))
    return metrics.compute()


print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []

# One result row per (detector, dataset) pair; gradient tracking is disabled
# for the whole evaluation.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            row = {"Detector": detector_name, "Dataset": dataset_name}
            row.update(_score(detector, loader))
            results.append(row)

# calculate mean scores over all datasets, use percent
df = pd.DataFrame(results)
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
mean_scores = df.groupby("Detector")[metric_columns].mean() * 100
# Rows are sorted by ascending AUROC and printed as CSV with two decimals.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery