Note
Go to the end to download the full example code.
GradNorm
Running GradNorm on CIFAR-10, with the Textures dataset as out-of-distribution data.
import logging

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from pytorch_ood.dataset.img import Textures
from pytorch_ood.detector import GradNorm
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
18
logging.basicConfig(level=logging.INFO)

# Fix the random seed (pytorch_ood helper) so repeated runs are reproducible.
fix_random_seed(123)

# Prefer the GPU, but fall back to the CPU so the example also runs on
# machines without CUDA (a hard-coded "cuda" device fails there).
device = "cuda" if torch.cuda.is_available() else "cpu"
Set up preprocessing and data
# Input transform registered for the "cifar10-pt" weights
# (presumably the preprocessing those weights were trained with).
trans = WideResNet.transform_for("cifar10-pt")

# In-distribution data: CIFAR-10 train and test splits.
dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)

# Out-of-distribution data: Textures. ToUnknown() is applied to every target
# (presumably relabelling them as "unknown"/OOD for OODMetrics).
dataset_out_test = Textures(
    root="data",
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# Create data loaders.
# NOTE: train_loader is not used anywhere in this example, because GradNorm
# needs no fitting.
train_loader = DataLoader(
    dataset_train,
    batch_size=128,
    shuffle=True,
    num_workers=10,
)

# Test loader over the concatenation of the IN and OOD test sets (unshuffled).
test_loader = DataLoader(
    dataset_in_test + dataset_out_test,
    batch_size=128,
    num_workers=10,
)
Stage 1: Create a DNN pre-trained on CIFAR-10
# Load a WideResNet with the "cifar10-pt" pretrained weights, move it to
# `device` and put it into evaluation mode.
model = WideResNet(num_classes=10, pretrained="cifar10-pt")
model = model.to(device)
model.eval()

# Freeze every parameter, then re-enable gradients for the final `fc` layer
# only: its parameters are the only ones the detector's filter selects below.
model.requires_grad_(False)
model.fc.requires_grad_(True)


def _is_fc_param(name):
    """Return True for parameter names belonging to the final ``fc`` layer."""
    return name.startswith("fc")


# Stage 2: Create detector, fitting is not required
detector = GradNorm(model, param_filter=_is_fc_param)
Stage 3: Evaluate the detector
print("Testing...")

metrics = OODMetrics()

# The loop is deliberately not wrapped in torch.no_grad(): GradNorm scores are
# derived from gradients of the `fc` parameters, so autograd must stay enabled.
for images, labels in test_loader:
    scores = detector(images.to(device))
    metrics.update(scores, labels)

print(metrics.compute())
This produces the following output: {'AUROC': 0.4999113380908966, 'AUTC': 0.5440057516098022, 'AUPR-IN': 0.31969308853149414, 'AUPR-OUT': 0.6802297830581665, 'FPR95TPR': 1.0}