OpenMax

OpenMax was originally proposed for Open Set Recognition, but it can be adapted for Out-of-Distribution (OOD) detection tasks.

Warning

OpenMax requires the libmr package, whose installation is broken at the moment. As a workaround, install cython and numpy first, and then install libmr manually.

14 from torch.utils.data import DataLoader
15 from torchvision.datasets import CIFAR10
16
17 from pytorch_ood.dataset.img import Textures
18 from pytorch_ood.detector import OpenMax
19 from pytorch_ood.model import WideResNet
20 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
21
22 fix_random_seed(123)  # seed the RNGs with a fixed value (presumably for reproducible results)
23
24 device = "cuda:0"  # NOTE(review): hard-coded to the first CUDA device; a CPU-only machine would need "cpu" here

Setup preprocessing and data

28 trans = WideResNet.transform_for("cifar10-pt")  # input transform tied to the "cifar10-pt" weights (presumably their training-time preprocessing)
29
30 dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)  # CIFAR-10 train split; used below to fit the detector
31 dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)  # CIFAR-10 test split: in-distribution test data
32 dataset_out_test = Textures(  # Textures serves as the out-of-distribution test set
33     root="data", download=True, transform=trans, target_transform=ToUnknown()  # ToUnknown presumably relabels every target as unknown/OOD -- TODO confirm
34 )
35
36 train_loader = DataLoader(dataset_train, batch_size=128, shuffle=True)
37
38 # single evaluation loader over the concatenated in-distribution and OOD test sets
39 test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=128)

Stage 1: Create a DNN pre-trained on CIFAR-10

43 model = WideResNet(num_classes=10, pretrained="cifar10-pt").to(device).eval()  # 10-class WideResNet with "cifar10-pt" weights, moved to the device and put in eval mode

Stage 2: Create and Fit OpenMax

47 detector = OpenMax(model, tailsize=25, alpha=5, euclid_weight=0.5)  # wrap the classifier; hyperparameter semantics are defined by pytorch_ood's OpenMax
48 detector.to(device)
49 detector.fit(train_loader)  # fit on the CIFAR-10 training loader (in-distribution data only)

Stage 3: Evaluate Detectors

53 metrics = OODMetrics()
54
55 for x, y in test_loader:
56     metrics.update(detector(x.to(device)), y)  # detector output for the batch (presumably outlier scores) paired with its labels
57
58 print(metrics.compute())  # compute the metrics over everything accumulated above and print them

Gallery generated by Sphinx-Gallery