Note
Go to the end to download the full example code.
Outlier Exposure
We train a model with Outlier Exposure on CIFAR10.
We use a model pre-trained on a \(32 \times 32\) resized version of ImageNet as a foundation.
As outlier data, we use TinyImages300k, a cleaned version of the
TinyImages database, which contains random images scraped from the internet.
15 import torch
16 import torchvision.transforms as tvt
17 from torch.optim import Adam
18 from torch.utils.data import DataLoader
19 from torchvision.datasets import CIFAR10
20
21 from pytorch_ood.dataset.img import Textures, TinyImages300k
22 from pytorch_ood.detector import MaxSoftmax
23 from pytorch_ood.loss import OutlierExposureLoss
24 from pytorch_ood.model import WideResNet
25 from pytorch_ood.utils import OODMetrics, ToUnknown
26
# seed PyTorch's random number generator
torch.manual_seed(123)

# number of epochs to train for
n_epochs = 10
# use the first GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
Set up preprocessing and data
# every dataset is stored below the same root directory and shares one
# preprocessing pipeline: resize to 32x32, then convert to a tensor
data_root = "data"
batch_size = 64
trans = tvt.Compose([tvt.Resize(size=(32, 32)), tvt.ToTensor()])

# arguments common to both OOD datasets
ood_kwargs = {"root": data_root, "download": True, "transform": trans}

# ID training data
dataset_in_train = CIFAR10(root=data_root, train=True, download=True, transform=trans)

# OOD training data; ToUnknown() marks the labels as OOD, so that
# outlier exposure can tell ID and OOD training samples apart automatically
dataset_out_train = TinyImages300k(target_transform=ToUnknown(), **ood_kwargs)

# ID test data
dataset_in_test = CIFAR10(root=data_root, train=False, transform=trans)

# OOD test data, labels are again marked as OOD via ToUnknown()
dataset_out_test = Textures(target_transform=ToUnknown(), **ood_kwargs)

# each loader draws from the concatenation of ID and OOD data;
# only the training loader shuffles
train_data = dataset_in_train + dataset_out_train
test_data = dataset_in_test + dataset_out_test
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_data, batch_size=batch_size)
Create the DNN, pre-trained on ImageNet excluding the CIFAR10 classes
# start from the pre-trained WideResNet with its 1000-class output layer
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")

# CIFAR10 has fewer classes, so the final layer is swapped for a
# freshly initialized linear layer with 10 outputs
n_features = model.fc.in_features
model.fc = torch.nn.Linear(n_features, 10)

model = model.to(device)

# outlier exposure loss with alpha=0.5, optimized with Adam (default settings)
criterion = OutlierExposureLoss(alpha=0.5)
opti = Adam(model.parameters())
Define a function to test the model
def test():
    """
    Evaluate the OOD detection performance of the model on the test data.

    Scores every batch from ``test_loader`` with a ``MaxSoftmax`` detector
    built around the global ``model``, prints the resulting metrics and puts
    the model back into training mode afterwards.

    :return: the metrics, as returned by ``OODMetrics.compute()``
    """
    detector = MaxSoftmax(model)
    metrics = OODMetrics()
    model.eval()

    try:
        # no gradients are needed for evaluation
        with torch.no_grad():
            for x, y in test_loader:
                metrics.update(detector(x.to(device)), y)

        results = metrics.compute()
    finally:
        # always switch back, so that a failed evaluation does not leave the
        # model in eval mode for subsequent training steps
        model.train()

    print(results)
    return results
Start training
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    for inputs, targets in train_loader:
        # clear the gradients left over from the previous step
        opti.zero_grad()

        # forward pass and loss on the mixed ID/OOD batch
        outputs = model(inputs.to(device))
        loss = criterion(outputs, targets.to(device))

        # backward pass and parameter update
        loss.backward()
        opti.step()

    # evaluate once at the end of every epoch
    test()
Output:
Epoch 0
{'AUROC': 0.9438387155532837, 'AUPR-IN': 0.9145375490188599, 'AUPR-OUT': 0.9601001143455505, 'FPR95TPR': 0.3043999969959259}
Epoch 1
{'AUROC': 0.9723063111305237, 'AUPR-IN': 0.9310603737831116, 'AUPR-OUT': 0.9854252338409424, 'FPR95TPR': 0.10440000146627426}
Epoch 2
{'AUROC': 0.9726285338401794, 'AUPR-IN': 0.9353838562965393, 'AUPR-OUT': 0.9854604005813599, 'FPR95TPR': 0.10670000314712524}
Epoch 3
{'AUROC': 0.9664252996444702, 'AUPR-IN': 0.9456377625465393, 'AUPR-OUT': 0.9795949459075928, 'FPR95TPR': 0.18469999730587006}
Epoch 4
{'AUROC': 0.9807416200637817, 'AUPR-IN': 0.9635397791862488, 'AUPR-OUT': 0.9889929294586182, 'FPR95TPR': 0.09440000355243683}
Epoch 5
{'AUROC': 0.9845513701438904, 'AUPR-IN': 0.9637761116027832, 'AUPR-OUT': 0.9917054772377014, 'FPR95TPR': 0.06310000270605087}
Epoch 6
{'AUROC': 0.9830336570739746, 'AUPR-IN': 0.9557808637619019, 'AUPR-OUT': 0.9912922382354736, 'FPR95TPR': 0.05920000001788139}
Epoch 7
{'AUROC': 0.9907971620559692, 'AUPR-IN': 0.9819102883338928, 'AUPR-OUT': 0.9949544668197632, 'FPR95TPR': 0.04010000079870224}
Epoch 8
{'AUROC': 0.9874339699745178, 'AUPR-IN': 0.9670255184173584, 'AUPR-OUT': 0.9937484860420227, 'FPR95TPR': 0.041200000792741776}
Epoch 9
{'AUROC': 0.9871684312820435, 'AUPR-IN': 0.9701670408248901, 'AUPR-OUT': 0.9931120276451111, 'FPR95TPR': 0.051899999380111694}