OpenOOD v1.5 - CIFAR10

Reproduces the OpenOOD v1.5 benchmark for out-of-distribution (OOD) detection on CIFAR-10, using the pre-trained WideResNet model from the Hendrycks baseline paper.

11 import pandas as pd  # additional dependency, used here for convenience
12 import torch
13
14 from pytorch_ood.benchmark import CIFAR10_OpenOOD
15 from pytorch_ood.detector import MaxSoftmax, ReAct, ASH
16 from pytorch_ood.model import WideResNet
17 from pytorch_ood.utils import fix_random_seed
18
# Fix the random seed so that repeated runs are comparable.
fix_random_seed(123)

# Prefer the first CUDA GPU, but fall back to the CPU when no GPU is present
# instead of crashing when the model is moved to a non-existent CUDA device.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Passed as ``loader_kwargs`` to ``benchmark.evaluate``; presumably forwarded to
# the DataLoaders the benchmark builds — TODO confirm.
loader_kwargs = {"batch_size": 64}

# WideResNet with the "cifar10-pt" weights for the 10 CIFAR-10 classes,
# put into evaluation mode and moved to the selected device.
model = WideResNet(num_classes=10, pretrained="cifar10-pt").eval().to(device)
# Input transform associated with the "cifar10-pt" weights; handed to the benchmark datasets.
trans = WideResNet.transform_for("cifar10-pt")
# Normalization std for the "cifar10-pt" weights. Not used by the MSP detector;
# presumably kept for detectors that need it (e.g. input perturbation) — TODO confirm.
norm_std = WideResNet.norm_std_for("cifar10-pt")

To evaluate additional detectors, simply add them to the dictionary below.

# Detectors to benchmark, keyed by the name shown in the results table.
# Register further detectors with additional ``detectors["<name>"] = ...`` lines.
detectors = {}
detectors["MSP"] = MaxSoftmax(model)  # maximum softmax probability baseline
# CIFAR-10 OpenOOD v1.5 benchmark, with datasets stored under ./data and
# preprocessed by the model's input transform.
benchmark = CIFAR10_OpenOOD(root="data", transform=trans)
results = []

# Scoring needs no gradients, so autograd is disabled for the whole evaluation.
with torch.no_grad():
    for name, det in detectors.items():
        print(f"> Evaluating {name}")
        rows = benchmark.evaluate(det, loader_kwargs=loader_kwargs, device=device)
        # Each row is a dict of results (presumably one per OOD dataset, given the
        # "Dataset" index below); tag it with the detector that produced it.
        for row in rows:
            row["Detector"] = name
        results.extend(rows)

# Collect all rows into one table indexed by (dataset, detector), scale the
# remaining columns by 100 (metrics presumably in [0, 1]) and print as CSV.
df = pd.DataFrame(results)
scaled = df.set_index(["Dataset", "Detector"]) * 100
print(scaled.to_csv(float_format="%.2f"))

This should produce a table with results for the following OOD datasets:

Near-OOD:

* CIFAR100
* TinyImageNet

Far-OOD:

* MNIST
* SVHN
* Textures
* Places365

Gallery generated by Sphinx-Gallery