Gram

Running the Gram detector on CIFAR-10, with Textures as the out-of-distribution test data.

import logging

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from pytorch_ood.dataset.img import Textures
from pytorch_ood.detector import Gram
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
19
logging.basicConfig(level=logging.INFO)

# Fix the random seed (123) so that repeated runs, e.g. the shuffled training
# loader, behave the same way.
fix_random_seed(123)

# Use the GPU when one is available; otherwise fall back to the CPU instead of
# failing on the first `.to("cuda")` call.
device = "cuda" if torch.cuda.is_available() else "cpu"

Set up preprocessing and data

trans = WideResNet.transform_for("cifar10-pt")

# In-distribution data: the CIFAR-10 train and test splits.
dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)

# Out-of-distribution test data: Textures, with targets rewritten by ToUnknown
# (presumably to pytorch_ood's "unknown" label — confirm in the library docs).
dataset_out_test = Textures(
    root="data", download=True, transform=trans, target_transform=ToUnknown()
)

# Settings shared by both data loaders.
loader_kwargs = dict(batch_size=128, num_workers=10)

# create data loaders: the training loader is shuffled; the test loader walks the
# concatenation of the in-distribution and OOD test sets in order.
train_loader = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
test_loader = DataLoader(dataset_in_test + dataset_out_test, **loader_kwargs)

Stage 1: Create a DNN pre-trained on CIFAR-10

# Build the WideResNet with the "cifar10-pt" pre-trained weights, move it to the
# selected device and put it in eval mode.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.to(device).eval()

# Modules whose outputs the Gram detector monitors, listed from the input side
# of the network towards the output side.
layer1, layer2, layer3, layer4 = (
    model.conv1,
    model.block1,
    model.block2,
    model.block3,
)
layer5 = nn.Sequential(model.bn1, model.relu)

# Head applied to the last feature map: global average pooling, flatten, then
# the model's own `fc` layer.
head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), model.fc)

Stage 2: Create and fit the detector

# The Gram detector gets the head, the monitored feature layers and the number of
# classes. num_poles_list holds five values, one per monitored layer here
# (see pytorch_ood.detector.Gram for its exact meaning).
gram_layers = [layer1, layer2, layer3, layer4, layer5]
detector = Gram(head, gram_layers, num_classes=10, num_poles_list=[1, 2, 3, 4, 5])

print("Fitting...")
# Fit the detector on the CIFAR-10 training loader.
detector.fit(train_loader, device=device)

Stage 3: Evaluate the detector

print("Testing...")

# Collect detector scores and labels over the combined test set, then report metrics.
metrics = OODMetrics()

# Inference only: disabling autograd avoids recording computation graphs for
# every batch (harmless if the detector already disables gradients internally).
with torch.no_grad():
    for x, y in test_loader:
        metrics.update(detector(x.to(device)), y)

print(metrics.compute())

This produces the following output: {'AUROC': 0.8175439834594727, 'AUTC': 0.4554872214794159, 'AUPR-IN': 0.8401336073875427, 'AUPR-OUT': 0.7695250511169434, 'FPR95TPR': 0.8087999820709229}

Gallery generated by Sphinx-Gallery