Note: go to the end of this page to download the full example code.
Multi-Layer Mahalanobis
Running the MultiMahalanobis detector on CIFAR-10, with the Textures dataset as out-of-distribution data.
import logging

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from pytorch_ood.dataset.img import Textures
from pytorch_ood.detector import MultiMahalanobis
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed

logging.basicConfig(level=logging.INFO)

# Fix the random seed so repeated runs give comparable results.
fix_random_seed(123)

# Prefer the GPU, but fall back to the CPU so the example also runs on
# machines without CUDA (hard-coding "cuda" fails there).
device = "cuda" if torch.cuda.is_available() else "cpu"
Set up preprocessing and data
# Preprocessing that matches the "cifar10-pt" WideResNet weights loaded below.
trans = WideResNet.transform_for("cifar10-pt")

# In-distribution data: the CIFAR-10 train split (used to fit the detector)
# and the test split (used for evaluation).
dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)

# Out-of-distribution data: Textures, with its targets relabelled by
# ToUnknown() so the metrics treat these samples as unknown.
dataset_out_test = Textures(
    root="data", download=True, transform=trans, target_transform=ToUnknown()
)

train_loader = DataLoader(dataset_train, batch_size=128, shuffle=True)

# Evaluate on the in- and out-of-distribution test sets together;
# `Dataset + Dataset` builds a ConcatDataset.
test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=128)
Stage 1: Create a DNN pre-trained on CIFAR-10
# WideResNet with CIFAR-10 pre-trained weights, moved to `device` and switched
# to evaluation mode.
model = WideResNet(num_classes=10, pretrained="cifar10-pt").to(device).eval()

# Layers whose outputs are handed to MultiMahalanobis as feature sources.
layer1 = model.conv1
layer2 = model.block1
layer3 = model.block2
layer4 = model.block3


class MyLayer(nn.Module):
    """Apply a normalization module followed by an activation module.

    Used here to wrap the WideResNet's ``bn1`` and ``relu`` modules, so that
    their combined output can serve as one more detector layer. The behavior
    is equivalent to ``nn.Sequential(bn1, relu)``.

    Args:
        bn1: Module applied first (a batch-norm layer in this example).
        relu: Module applied to the output of ``bn1``.
    """

    def __init__(self, bn1: nn.Module, relu: nn.Module):
        super().__init__()
        # Assigning modules as attributes registers them as submodules, so
        # .to(device) / .eval() on this wrapper propagate to them.
        self.bn1 = bn1
        self.relu = relu

    def forward(self, x):
        x = self.bn1(x)
        return self.relu(x)


# Fifth detector layer: the model's bn1 followed by its relu.
layer5 = MyLayer(model.bn1, model.relu)
Stage 2: Create and fit the detector
# A single detector that draws on all five selected layers.
detector = MultiMahalanobis([layer1, layer2, layer3, layer4, layer5])

print("Fitting...")
# Fit the detector on the in-distribution training data.
detector.fit(train_loader, device=device)
Stage 3: Evaluate the detector
print("Testing...")
metrics = OODMetrics()

# Score each test batch on `device`. OODMetrics accumulates the scores
# together with the labels; OOD samples carry the label set by ToUnknown().
# NOTE(review): no torch.no_grad() here, because the detector may need
# gradients (e.g. for input perturbation). Confirm before adding one.
for x, y in test_loader:
    metrics.update(detector(x.to(device)), y)

print(metrics.compute())
This produces the following output: {'AUROC': 0.9601144790649414, 'AUPR-IN': 0.9439688324928284, 'AUPR-OUT': 0.9745389223098755, 'FPR95TPR': 0.23440000414848328}