Note
Go to the end to download the full example code.
PNML
Comparing PNML, MaxSoftmax, and EnergyBased on a small CIFAR-10 benchmark, with Textures as the out-of-distribution data.
This example mirrors the other detector demos, but keeps fitting and evaluation on small subsets so it stays reasonably fast.
13 import logging
14
15 import torch
16 from torch.utils.data import DataLoader, Subset
17 from torchvision.datasets import CIFAR10
18
19 from pytorch_ood.dataset.img import Textures
20 from pytorch_ood.detector import EnergyBased, MaxSoftmax, PNML
21 from pytorch_ood.model import WideResNet
22 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
23
# Show INFO-level log records (e.g. progress messages from the library).
logging.basicConfig(level=logging.INFO)

# Fix the random seed (pytorch_ood helper) so repeated runs are comparable.
fix_random_seed(123)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
batch_size = 128

# Use subsets to keep this example lightweight while still running a real benchmark:
# training samples for fitting, in-distribution test samples, OOD test samples.
n_train, n_in_test, n_out_test = 1_000, 1_000, 1_000
Set up preprocessing and data
# Preprocessing matching the "cifar10-pt" weights loaded for the model below.
trans = WideResNet.transform_for("cifar10-pt")

# In-distribution data: CIFAR-10 train split (fitting) and test split (evaluation).
dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)

# Out-of-distribution data: Textures, relabelled by ToUnknown so these samples
# are presumably counted as "unknown" by OODMetrics.
dataset_out_test = Textures(
    root="data",
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# Only the first n_* samples of each split are used to keep the run short.
train_subset = Subset(dataset_train, range(n_train))
# Equivalent to `Subset(...) + Subset(...)`: in-distribution samples first, then OOD.
test_subset = torch.utils.data.ConcatDataset(
    [
        Subset(dataset_in_test, range(n_in_test)),
        Subset(dataset_out_test, range(n_out_test)),
    ]
)

train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_subset, batch_size=batch_size)
56
57
def evaluate(name, detector, loader):
    """Score every batch in ``loader`` with ``detector`` and print the OOD metrics.

    Args:
        name: Label printed above the metrics.
        detector: Callable that maps an input batch to OOD scores. Inputs are
            moved to the module-level ``device`` before the call.
        loader: Iterable of ``(x, y)`` batches. ``y`` is passed to
            ``OODMetrics.update`` unchanged. In this script, the OOD targets
            are relabelled by ``ToUnknown``.
    """
    metrics = OODMetrics()
    # Inference only. MSP and Energy score from a plain forward pass, and PNML
    # is assumed to be closed-form after fitting, so recording an autograd
    # graph (and keeping activations) for every batch is pure overhead.
    # NOTE(review): a detector that needs input gradients (e.g. ODIN-style
    # perturbation) must not be evaluated under no_grad.
    with torch.no_grad():
        for x, y in loader:
            metrics.update(detector(x.to(device)), y)

    print(f"{name}:")
    print(metrics.compute())
Stage 1: Create a DNN pre-trained on CIFAR-10
# WideResNet with the "cifar10-pt" pretrained weights. The model is moved to the
# selected device and put into eval mode for inference.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.to(device)
model.eval()
Stage 2: Create and fit detectors
# PNML is given the feature extractor and the linear head separately
# (model.features / model.fc); the two baselines wrap the whole classifier.
pnml = PNML(model.features, model.fc).to(device)
msp, energy = MaxSoftmax(model).to(device), EnergyBased(model).to(device)

# PNML is the only detector fitted on training data in this script.
print("Fitting...")
pnml.fit(train_loader)
Stage 3: Evaluate detectors
print("Testing...")

# Every detector is scored on the same mixed CIFAR-10 / Textures test loader.
for name, detector in (("PNML", pnml), ("MSP", msp), ("Energy", energy)):
    evaluate(name, detector, test_loader)