OpenMax

OpenMax was originally proposed for open set recognition, but it can be adapted for out-of-distribution (OOD) detection.

Warning

OpenMax requires libmr, whose installation is currently broken. To use OpenMax, first install cython and numpy, then install libmr manually.

13 from torch.utils.data import DataLoader
14 from torchvision.datasets import CIFAR10
15
16 from pytorch_ood.dataset.img import Textures
17 from pytorch_ood.detector import OpenMax
18 from pytorch_ood.model import WideResNet
19 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
20
import torch

# Fix the random seed(s) so repeated runs of this example are comparable.
fix_random_seed(123)

# Prefer the first GPU, but fall back to the CPU so the example also runs on
# machines without CUDA (a hard-coded "cuda:0" fails there).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

Set up preprocessing and data

# Input preprocessing associated with the "cifar10-pt" WideResNet weights.
trans = WideResNet.transform_for("cifar10-pt")

# Keyword arguments shared by every dataset constructed below.
_shared_dataset_kwargs = {"root": "data", "download": True, "transform": trans}

# In-distribution data: the CIFAR-10 training and test splits.
dataset_train = CIFAR10(train=True, **_shared_dataset_kwargs)
dataset_in_test = CIFAR10(train=False, **_shared_dataset_kwargs)

# Out-of-distribution data: Textures, whose targets are rewritten by ToUnknown().
dataset_out_test = Textures(target_transform=ToUnknown(), **_shared_dataset_kwargs)

# Create data loaders: a shuffled one over the training split (used later to
# fit the detector) and an unshuffled one over the concatenated
# in-distribution + out-of-distribution test sets.
train_loader = DataLoader(dataset_train, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=128)

Stage 1: Create a DNN pre-trained on CIFAR-10

# WideResNet with 10 output classes, loaded with the "cifar10-pt" pre-trained
# weights, then moved to the selected device and put into evaluation mode.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.to(device)
model.eval()

Stage 2: Create and Fit OpenMax

# Wrap the pre-trained model in an OpenMax detector. tailsize, alpha and
# euclid_weight are OpenMax hyperparameters; see the pytorch-ood OpenMax
# documentation for their exact meaning.
detector = OpenMax(
    model,
    tailsize=25,
    alpha=5,
    euclid_weight=0.5,
)

# Fit the detector on the training loader, passing the selected device.
detector.fit(train_loader, device=device)

Stage 3: Evaluate Detector

import torch

# Accumulates detector scores and labels over all test batches.
metrics = OODMetrics()

# Evaluation is inference only: disable autograd so no computation graph is
# built for the test batches, saving memory and time.
with torch.no_grad():
    for x, y in test_loader:
        # detector(...) produces per-sample scores for the batch; they are fed
        # to OODMetrics together with the batch labels.
        metrics.update(detector(x.to(device)), y)

print(metrics.compute())

Gallery generated by Sphinx-Gallery