CIFAR 100

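The evaluation protocol is the same as for CIFAR 10. The scores reported below are means over all OOD datasets, given in percent and sorted by AUROC.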

| Detector         | AUROC | AUTC  | AUPR-IN | AUPR-OUT | FPR95TPR |
|------------------|-------|-------|---------|----------|----------|
| Gram             | 48.29 | 50.66 | 38.24   | 63.76    | 91.97    |
| SHE              | 59.43 | 43.67 | 68.37   | 77.44    | 100.00   |
| Mahalanobis      | 75.35 | 45.59 | 65.62   | 81.59    | 58.87    |
| MSP              | 78.78 | 37.32 | 71.34   | 82.37    | 57.67    |
| Mahalanobis+ODIN | 79.24 | 44.89 | 68.69   | 84.58    | 55.91    |
| KLMatching       | 79.88 | 41.07 | 68.23   | 83.53    | 60.02    |
| ODIN             | 80.80 | 44.90 | 73.40   | 83.96    | 54.92    |
| Entropy          | 81.19 | 38.44 | 73.08   | 84.61    | 56.49    |
| ViM              | 81.73 | 43.50 | 72.91   | 85.87    | 49.86    |
| RMD              | 83.23 | 39.43 | 74.56   | 86.94    | 50.55    |
| MaxLogit         | 84.70 | 41.89 | 78.33   | 86.66    | 47.40    |
| EnergyBased      | 85.00 | 41.89 | 78.69   | 86.88    | 46.70    |
| MultiMahalanobis | 85.33 | 45.93 | 77.84   | 89.51    | 39.25    |
| DICE             | 85.35 | 41.84 | 78.99   | 87.32    | 46.17    |

43 import pandas as pd  # additional dependency, used here for convenience
44 import torch
45 from torch.utils.data import DataLoader
46 from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST
47 from torch import nn
48
49 from pytorch_ood.dataset.img import (
50     LSUNCrop,
51     LSUNResize,
52     Textures,
53     TinyImageNetCrop,
54     TinyImageNetResize,
55     Places365,
56 )
57 from pytorch_ood.detector import (
58     ODIN,
59     EnergyBased,
60     Entropy,
61     KLMatching,
62     Mahalanobis,
63     MaxLogit,
64     MaxSoftmax,
65     ViM,
66     RMD,
67     DICE,
68     SHE,
69     Gram,
70     MultiMahalanobis,
71 )
72 from pytorch_ood.model import WideResNet
73 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
74
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# that the script does not fail on machines without a GPU. The chosen device
# is used when moving the model, when fitting detectors, and for each batch.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed (123) for this run.
fix_random_seed(123)

# setup preprocessing
# `trans` is the input transform applied to every dataset below; `norm_std`
# is handed to the detectors that take a `norm_std` argument.
trans = WideResNet.transform_for("cifar100-pt")
norm_std = WideResNet.norm_std_for("cifar100-pt")

Setup datasets

 85 dataset_in_test = CIFAR100(root="data", train=False, transform=trans, download=True)
 86
 87 # create all OOD datasets
 88 ood_datasets = [
 89     Textures,
 90     TinyImageNetCrop,
 91     TinyImageNetResize,
 92     LSUNCrop,
 93     LSUNResize,
 94     Places365,
 95     CIFAR10,
 96     MNIST,
 97     FashionMNIST,
 98 ]
 99 datasets = {}
100 for ood_dataset in ood_datasets:
101     dataset_out_test = ood_dataset(
102         root="data", transform=trans, target_transform=ToUnknown(), download=True
103     )
104     test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=256, num_workers=12)
105     datasets[ood_dataset.__name__] = test_loader

Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper

print("STAGE 1: Creating a Model")
model = WideResNet(num_classes=100, pretrained="cifar100-pt")
model = model.eval().to(device)


def _feature_stages():
    """Return a fresh list of the model stages handed to the multi-layer detectors."""
    return [
        model.conv1,
        model.block1,
        model.block2,
        model.block3,
        nn.Sequential(model.bn1, model.relu),
    ]


# Stage 2: Create OOD detector
print("STAGE 2: Creating OOD Detectors")
# The insertion order below determines the order in which detectors are
# fitted and evaluated later on.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
    "MultiMahalanobis": MultiMahalanobis(_feature_stages()),
    "Gram": Gram(
        num_classes=100,
        head=nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc),
        feature_layers=_feature_stages(),
    ),
}

Stage 2: Fit the detectors to the training data (some detectors require fitting, others do not)

print(f"> Fitting {len(detectors)} detectors")

# Every detector is fitted on the same loader over CIFAR-100 with train=True.
train_set = CIFAR100(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=256, num_workers=12)

for label, ood_detector in detectors.items():
    print(f"--> Fitting {label}")
    ood_detector.fit(loader_in_train, device=device)

Stage 3: Evaluate Detectors

print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")


def _evaluate(detector, loader):
    """Run *detector* on every batch of *loader* and return the computed metrics."""
    metrics = OODMetrics()
    for x, y in loader:
        metrics.update(detector(x.to(device)), y.to(device))
    return metrics.compute()


# One result row per (detector, dataset) pair.
results = []
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            results.append(
                {
                    "Detector": detector_name,
                    "Dataset": dataset_name,
                    **_evaluate(detector, loader),
                }
            )

# calculate mean scores over all datasets, use percent
metric_columns = ["AUROC", "AUTC", "AUPR-IN", "AUPR-OUT", "FPR95TPR"]
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector")[metric_columns].mean() * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery