CIFAR 100

The evaluation is the same as for CIFAR 10.

| Detector         | AUROC | AUTC  | AUPR-IN | AUPR-OUT | FPR95TPR |
|------------------|-------|-------|---------|----------|----------|
| Gram             | 48.29 | 50.66 | 38.24   | 63.76    | 91.97    |
| SHE              | 59.43 | 43.67 | 68.37   | 77.44    | 100.00   |
| Mahalanobis      | 75.35 | 45.59 | 65.62   | 81.59    | 58.87    |
| MSP              | 78.78 | 37.32 | 71.34   | 82.37    | 57.67    |
| Mahalanobis+ODIN | 79.24 | 44.89 | 68.69   | 84.58    | 55.91    |
| KLMatching       | 79.88 | 41.07 | 68.23   | 83.53    | 60.02    |
| ODIN             | 80.80 | 44.90 | 73.40   | 83.96    | 54.92    |
| Entropy          | 81.19 | 38.44 | 73.08   | 84.61    | 56.49    |
| ViM              | 81.73 | 43.50 | 72.91   | 85.87    | 49.86    |
| RMD              | 83.23 | 39.43 | 74.56   | 86.94    | 50.55    |
| MaxLogit         | 84.70 | 41.89 | 78.33   | 86.66    | 47.40    |
| EnergyBased      | 85.00 | 41.89 | 78.69   | 86.88    | 46.70    |
| MultiMahalanobis | 85.33 | 45.93 | 77.84   | 89.51    | 39.25    |
| DICE             | 85.35 | 41.84 | 78.99   | 87.32    | 46.17    |

43 import pandas as pd  # additional dependency, used here for convenience
44 import torch
45 from torch.utils.data import DataLoader
46 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST
47 from torch import nn
48
49 from pytorch_ood.dataset.img import (
50     LSUNCrop,
51     LSUNResize,
52     Textures,
53     TinyImageNetCrop,
54     TinyImageNetResize,
55     Places365,
56 )
57 from pytorch_ood.detector import (
58     ODIN,
59     EnergyBased,
60     Entropy,
61     GEN,
62     KLMatching,
63     Mahalanobis,
64     MahalanobisODIN,
65     MaxLogit,
66     MaxSoftmax,
67     ViM,
68     RMD,
69     DICE,
70     SHE,
71     Gram,
72     GMM,
73     MultiMahalanobis,
74     RankFeat,
75     fDBD,
76 )
77 from pytorch_ood.model import WideResNet
78 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
79
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so that the later `.to(device)` calls do not fail on machines without a
# GPU. NOTE: running this benchmark on the CPU is expected to be much slower.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed before any model, detector or data loader is created.
fix_random_seed(123)

# setup preprocessing
# Both values are looked up for the "cifar100-pt" identifier, the same one that
# is later passed as `pretrained=` when the WideResNet is created. `norm_std`
# is handed to the ODIN-based detectors further down.
trans = WideResNet.transform_for("cifar100-pt")
norm_std = WideResNet.norm_std_for("cifar100-pt")

Set up datasets

 90 dataset_in_test = CIFAR100(root="data", train=False, transform=trans, download=True)
 91
 92 # create all OOD datasets
 93 ood_datasets = [
 94     Textures,
 95     TinyImageNetCrop,
 96     TinyImageNetResize,
 97     LSUNCrop,
 98     LSUNResize,
 99     Places365,
100     CIFAR10,
101     MNIST,
102     FashionMNIST,
103 ]
104 datasets = {}
105 for ood_dataset in ood_datasets:
106     dataset_out_test = ood_dataset(
107         root="data", transform=trans, target_transform=ToUnknown(), download=True
108     )
109     test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=256, num_workers=12)
110     datasets[ood_dataset.__name__] = test_loader

Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper

print("STAGE 1: Creating a Model")
# WideResNet with 100 output classes and the "cifar100-pt" pre-trained weights,
# switched to eval mode and moved to the selected device.
model = WideResNet(num_classes=100, pretrained="cifar100-pt").eval().to(device)


def _feature_layers():
    """Return a fresh list of the model stages used by the multi-layer detectors.

    A new list (and a new ``nn.Sequential`` wrapper around ``bn1`` + ``relu``)
    is built on every call, so each detector receives its own objects.
    """
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


# Stage 2: Create OOD detector
print("STAGE 2: Creating OOD Detectors")
# Insertion order is kept as-is: it determines the order in which the
# detectors are fitted and evaluated later on.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis+ODIN": MahalanobisODIN(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "GEN": GEN(model),
    "GMM": GMM(model.features),
    "fDBD": fDBD(encoder=model.features, head=model.fc),
    "RankFeat": RankFeat(backbone=model.feature_maps, head=model.forward_feature_maps),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(encoder=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
    "MultiMahalanobis": MultiMahalanobis(_feature_layers()),
    "Gram": Gram(
        num_classes=100,
        # Head applied after the last feature layer: pool, flatten, classify.
        head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
        feature_layers=_feature_layers(),
    ),
}

Stage 2: Fit detectors to training data (some require this, some do not)

print(f"> Fitting {len(detectors)} detectors")

# In-distribution training data (CIFAR-100, train=True), preprocessed with the
# same transform as the test data.
train_set = CIFAR100(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=256, num_workers=12)

# Every detector is moved to the device and fitted on the same loader,
# in dictionary order.
for name in detectors:
    print(f"--> Fitting {name}")
    current = detectors[name]
    current.to(device)
    current.fit(loader_in_train)

Stage 3: Evaluate Detectors

def _score(detector, loader):
    """Run ``detector`` over every batch of ``loader`` and return the computed metrics.

    A new ``OODMetrics`` accumulator is used per call; inputs and labels are
    moved to ``device`` before being passed on.
    """
    accumulator = OODMetrics()
    for inputs, labels in loader:
        accumulator.update(detector(inputs.to(device)), labels.to(device))
    return accumulator.compute()


print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []

# The whole evaluation runs with gradient tracking disabled.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            # One result row per (detector, dataset) pair.
            results.append(
                {"Detector": detector_name, "Dataset": dataset_name, **_score(detector, loader)}
            )

# calculate mean scores over all datasets, use percent
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean() * 100

# Print as CSV with two decimals, ordered by AUROC.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery