Note
Go to the end to download the full example code.
OpenMax
OpenMax was originally proposed for Open Set Recognition (Bendale & Boult, "Towards Open Set Deep Networks", CVPR 2016), but it can be adapted for Out-of-Distribution (OOD) detection tasks.
Warning
OpenMax requires libmr, whose installation is currently broken. To use OpenMax, first install cython and numpy, and then install libmr manually.
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from pytorch_ood.dataset.img import Textures
from pytorch_ood.detector import OpenMax
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
20
# Fix the random seed so that repeated runs of the example are comparable.
fix_random_seed(123)

# Prefer the first GPU, but fall back to the CPU so the example still runs
# (more slowly) on machines without CUDA instead of crashing on `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
Set up preprocessing and data
# Input preprocessing that matches the "cifar10-pt" WideResNet weights.
trans = WideResNet.transform_for("cifar10-pt")

DATA_ROOT = "data"
BATCH_SIZE = 128

# In-distribution data: CIFAR-10 train split (used to fit OpenMax) and test split.
dataset_train = CIFAR10(root=DATA_ROOT, train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root=DATA_ROOT, train=False, download=True, transform=trans)

# Out-of-distribution data: Textures, with every target passed through ToUnknown
# (presumably mapping it to pytorch-ood's OOD/"unknown" label -- confirm in the docs).
dataset_out_test = Textures(
    root=DATA_ROOT,
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)

# Create the evaluation loader: `+` concatenates the two test datasets, so the
# unshuffled loader yields all CIFAR-10 test samples followed by the Textures samples.
test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=BATCH_SIZE)
Stage 1: Create a DNN pre-trained on CIFAR-10
# WideResNet with the "cifar10-pt" pre-trained weights, moved to `device` and
# switched to evaluation mode (both calls return the module itself).
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.to(device)
model.eval()
Stage 2: Create and Fit OpenMax
# OpenMax hyperparameters. The meanings below are as described in the pytorch-ood
# OpenMax documentation -- confirm there:
#   tailsize      - size of the distance tail used for the per-class EVT fit
#   alpha         - number of top-ranked classes whose scores are revised
#   euclid_weight - weight of the Euclidean term in the Euclidean/cosine distance
openmax_params = dict(tailsize=25, alpha=5, euclid_weight=0.5)
detector = OpenMax(model, **openmax_params)

# Fit the detector on the CIFAR-10 training data.
detector.fit(train_loader, device=device)
Stage 3: Evaluate the Detector
metrics = OODMetrics()

# Pure inference: disable autograd so no computation graphs are built (or kept
# alive by the accumulated scores) for the test batches.
with torch.no_grad():
    for x, y in test_loader:
        # Only the inputs are moved to `device`; labels are passed as yielded.
        metrics.update(detector(x.to(device)), y)

print(metrics.compute())