Gram

Running the Gram detector on CIFAR-10, with the Textures dataset as out-of-distribution test data.

import logging

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from pytorch_ood.dataset.img import Textures
from pytorch_ood.detector import Gram
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
# Show INFO-level log messages.
logging.basicConfig(level=logging.INFO)

# Fix the random seed so repeated runs (e.g. the shuffled training order) match.
fix_random_seed(123)

# Use the GPU when one is available; fall back to the CPU so the example still
# runs (more slowly) on machines without CUDA instead of failing on `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

Set up preprocessing and data

# Input transform matching the "cifar10-pt" WideResNet weights used below.
trans = WideResNet.transform_for("cifar10-pt")

# In-distribution data: CIFAR-10 train split (used for fitting) and test split.
dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)

# Out-of-distribution data: Textures, with targets marked as unknown via ToUnknown().
dataset_out_test = Textures(
    root="data",
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# Create data loaders: a shuffled loader over the training split for fitting the
# detector, and an unshuffled loader over the concatenated ID + OOD test sets.
loader_kwargs = dict(batch_size=128, num_workers=10)
train_loader = DataLoader(dataset_train, shuffle=True, **loader_kwargs)
test_loader = DataLoader(dataset_in_test + dataset_out_test, **loader_kwargs)

Stage 1: Create a DNN pre-trained on CIFAR-10

# Load a WideResNet with the "cifar10-pt" pre-trained weights, move it to the
# selected device and switch it to evaluation mode.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.to(device).eval()

# Split the network into the feature stages handed to Gram below: the first
# convolution, the three blocks, and the final batch-norm + ReLU.
layer1, layer2, layer3, layer4 = model.conv1, model.block1, model.block2, model.block3
layer5 = nn.Sequential(model.bn1, model.relu)

# Classification head applied after the last stage: global average pooling,
# flattening, then the model's fully-connected classifier.
head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    model.fc,
)

Stage 2: Create and fit the detector

# Build the Gram detector from the classification head and the ordered list of
# feature stages defined above.
feature_layers = [layer1, layer2, layer3, layer4, layer5]
detector = Gram(
    head,
    feature_layers,
    num_classes=10,
    num_poles_list=[1, 2, 3, 4, 5],
)

# Fit the detector on the in-distribution training data.
print("Fitting...")
detector.fit(train_loader, device=device)

Stage 3: Evaluate the detector

# Score the concatenated ID + OOD test set and accumulate OOD metrics.
print("Testing...")

metrics = OODMetrics()

# Only the detector's scores are needed here (no backward pass), so disable
# autograd to avoid building computation graphs during scoring.
with torch.no_grad():
    for x, y in test_loader:
        metrics.update(detector(x.to(device)), y)

# compute() returns the dict of metric values shown in the output below;
# printing the OODMetrics object itself does not guarantee that.
print(metrics.compute())

This prints the following metrics:

# {'AUROC': 0.8175439834594727, 'AUTC': 0.4554872214794159, 'AUPR-IN': 0.8401336073875427, 'AUPR-OUT': 0.7695250511169434, 'FPR95TPR': 0.8087999820709229}

Gallery generated by Sphinx-Gallery