CIFAR 10

Example benchmark code for CIFAR10

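| Detector         | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|------------------|-------|---------|----------|----------|
| KLMatching       | 88.73 | 86.95   | 85.10    | 58.73    |
| SHE              | 90.67 | 89.63   | 77.72    | 37.41    |
| MSP              | 91.85 | 88.55   | 93.57    | 28.43    |
| Entropy          | 92.48 | 90.16   | 93.87    | 28.29    |
| DICE             | 92.63 | 91.07   | 93.30    | 32.78    |
| MaxLogit         | 93.06 | 91.44   | 93.74    | 31.18    |
| EnergyBased      | 93.10 | 91.51   | 93.78    | 31.05    |
| ODIN             | 93.20 | 92.12   | 93.94    | 31.65    |
| RMD              | 94.03 | 92.73   | 94.65    | 25.42    |
| Mahalanobis      | 94.06 | 92.42   | 95.17    | 22.79    |
| ViM              | 94.49 | 93.42   | 95.34    | 23.48    |
| Mahalanobis+ODIN | 94.87 | 93.69   | 95.79    | 21.05    |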

38 import pandas as pd  # additional dependency, used here for convenience
39 import torch
40 from torch.utils.data import DataLoader
41 from torchvision.datasets import CIFAR10, CIFAR100, MNIST, FashionMNIST
42
43 from pytorch_ood.dataset.img import (
44     LSUNCrop,
45     LSUNResize,
46     Textures,
47     TinyImageNetCrop,
48     TinyImageNetResize,
49     Places365,
50 )
51 from pytorch_ood.detector import (
52     ODIN,
53     EnergyBased,
54     Entropy,
55     KLMatching,
56     Mahalanobis,
57     MaxLogit,
58     MaxSoftmax,
59     ViM,
60     RMD,
61     DICE,
62     SHE,
63 )
64 from pytorch_ood.model import WideResNet
65 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
66
# Run on the first CUDA device when one is available; otherwise fall back to
# the CPU so the example does not crash on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Fix the random seed via the pytorch_ood helper (presumably so that repeated
# runs produce comparable numbers).
fix_random_seed(123)

Setup preprocessing

# Look up what WideResNet registers for the "cifar10-pt" weights:
# `norm_std` is later handed to the detectors that take a `norm_std` argument,
# `trans` is the transform applied to every dataset below.
norm_std = WideResNet.norm_std_for("cifar10-pt")
trans = WideResNet.transform_for("cifar10-pt")

Setup datasets

 79 dataset_in_test = CIFAR10(root="data", train=False, transform=trans, download=True)
 80
 81 # create all OOD datasets
 82 ood_datasets = [
 83     Textures,
 84     TinyImageNetCrop,
 85     TinyImageNetResize,
 86     LSUNCrop,
 87     LSUNResize,
 88     Places365,
 89     CIFAR100,
 90     MNIST,
 91     FashionMNIST,
 92 ]
 93 datasets = {}
 94 for ood_dataset in ood_datasets:
 95     dataset_out_test = ood_dataset(
 96         root="data", transform=trans, target_transform=ToUnknown(), download=True
 97     )
 98     test_loader = DataLoader(
 99         dataset_in_test + dataset_out_test, batch_size=256, num_workers=12
100     )
101     datasets[ood_dataset.__name__] = test_loader

Stage 1: Create DNN with pre-trained weights from the Hendrycks baseline paper

print("STAGE 1: Creating a Model")
# Build the 10-class WideResNet with the "cifar10-pt" pre-trained weights,
# then switch it to evaluation mode and move it to the target device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.eval().to(device)

Stage 2: Create OOD detectors

print("STAGE 2: Creating OOD Detectors")

# Detectors keyed by the name used in the results table. Some wrap the full
# model, others work on `model.features` and, where needed, take the weights
# and bias of the final layer `model.fc` explicitly.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    # Mahalanobis with `norm_std` and `eps` set, i.e. the same extra arguments
    # that the ODIN detector below receives.
    "Mahalanobis+ODIN": Mahalanobis(model.features, norm_std=norm_std, eps=0.002),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "SHE": SHE(model.features, model.fc),
    "MSP": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
    "ODIN": ODIN(model, norm_std=norm_std, eps=0.002),
    "DICE": DICE(model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65),
    "RMD": RMD(model.features),
}

# fit detectors to training data (some require this, some do not)
print(f"> Fitting {len(detectors)} detectors")
train_set = CIFAR10(root="data", train=True, transform=trans)
loader_in_train = DataLoader(train_set, batch_size=256, num_workers=12)
for label, current in detectors.items():
    print(f"--> Fitting {label}")
    current.fit(loader_in_train, device=device)

Stage 3: Evaluate Detectors

print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")

# One record per (detector, dataset) pair: the two names plus whatever
# OODMetrics.compute() returns for that pair.
results = []

# The whole evaluation runs with gradient tracking disabled.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            # Fresh accumulator for every detector/dataset combination.
            metrics = OODMetrics()
            for inputs, labels in loader:
                scores = detector(inputs.to(device))
                metrics.update(scores, labels.to(device))

            results.append(
                {
                    "Detector": detector_name,
                    "Dataset": dataset_name,
                    **metrics.compute(),
                }
            )
155
# Calculate mean scores over all datasets per detector, in percent.
# `numeric_only=True` keeps the string-valued "Dataset" column out of the
# aggregation; without it, pandas >= 2.0 raises a TypeError when it tries to
# average that column.
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
# Emit the table as CSV, sorted by ascending AUROC, with two decimals.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery