Outlier Exposure

We train a model with Outlier Exposure on the CIFAR10 dataset.

As a foundation, we can use a model pre-trained on a version of ImageNet resized to \(32 \times 32\). As outlier data, we use TinyImages300k, a cleaned version of the TinyImages database, which contains random images scraped from the internet.

14 import torch
15 import torchvision.transforms as tvt
16 from torch.optim import Adam
17 from torch.utils.data import DataLoader
18 from torchvision.datasets import CIFAR10
19
20 from pytorch_ood.dataset.img import Textures, TinyImages300k
21 from pytorch_ood.detector import MaxSoftmax
22 from pytorch_ood.loss import OutlierExposureLoss
23 from pytorch_ood.model import WideResNet
24 from pytorch_ood.utils import OODMetrics, ToUnknown
25
# fix the RNG seed so that repeated runs are comparable
torch.manual_seed(123)

# number of passes over the training data
n_epochs = 10
# prefer the first GPU, but fall back to the CPU so that the example
# still runs on machines without CUDA
device = "cuda:0" if torch.cuda.is_available() else "cpu"

Setup preprocessing and data

# Preprocessing shared by all datasets: resize to 32x32, then convert to a tensor.
trans = tvt.Compose([tvt.Resize(size=(32, 32)), tvt.ToTensor()])

# IN training data: the CIFAR10 training split.
dataset_in_train = CIFAR10(root="data", train=True, download=True, transform=trans)

# OOD training data. ToUnknown() marks the labels as OOD, which lets
# outlier exposure decide automatically whether a training sample is IN or OOD.
dataset_out_train = TinyImages300k(
    root="data",
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# IN test data: the CIFAR10 test split.
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)

# OOD test data, again with labels marked as OOD via ToUnknown().
dataset_out_test = Textures(
    root="data",
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# Data loaders over the concatenated IN + OOD datasets; only training is shuffled.
train_loader = DataLoader(
    dataset=dataset_in_train + dataset_out_train,
    batch_size=64,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=dataset_in_test + dataset_out_test,
    batch_size=64,
)

Create a DNN pre-trained on ImageNet, excluding the CIFAR10 classes

# Start from the "imagenet32-nocifar" pretrained weights, which come with a 1000-class head.
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")

# CIFAR10 has only 10 classes, so swap the final layer for a fresh 10-way
# linear layer with the same number of input features.
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,
    out_features=10,
)

model.to(device)

# Optimizer over all model parameters, and the outlier exposure training loss.
opti = Adam(params=model.parameters())
criterion = OutlierExposureLoss(alpha=0.5)

Define a function to test the model

def test():
    """Evaluate the current model on the test loader and print the OOD metrics.

    Scores every test batch with a ``MaxSoftmax`` detector wrapped around
    the global ``model``, accumulates the scores in an ``OODMetrics``
    object and prints the computed result. The model is put into eval
    mode for the evaluation and back into train mode before returning.
    """
    detector = MaxSoftmax(model)
    metrics = OODMetrics()

    model.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            scores = detector(inputs.to(device))
            metrics.update(scores, labels)

    print(metrics.compute())

    # hand the model back to the training loop in train mode
    model.train()

Start training

# One optimization pass over the training loader per epoch, followed by an evaluation.
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # clear gradients left over from the previous step
        opti.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        opti.step()

    # report OOD detection metrics after every epoch
    test()

Output:

Epoch 0
{'AUROC': 0.9438387155532837, 'AUPR-IN': 0.9145375490188599, 'AUPR-OUT': 0.9601001143455505, 'FPR95TPR': 0.3043999969959259}
Epoch 1
{'AUROC': 0.9723063111305237, 'AUPR-IN': 0.9310603737831116, 'AUPR-OUT': 0.9854252338409424, 'FPR95TPR': 0.10440000146627426}
Epoch 2
{'AUROC': 0.9726285338401794, 'AUPR-IN': 0.9353838562965393, 'AUPR-OUT': 0.9854604005813599, 'FPR95TPR': 0.10670000314712524}
Epoch 3
{'AUROC': 0.9664252996444702, 'AUPR-IN': 0.9456377625465393, 'AUPR-OUT': 0.9795949459075928, 'FPR95TPR': 0.18469999730587006}
Epoch 4
{'AUROC': 0.9807416200637817, 'AUPR-IN': 0.9635397791862488, 'AUPR-OUT': 0.9889929294586182, 'FPR95TPR': 0.09440000355243683}
Epoch 5
{'AUROC': 0.9845513701438904, 'AUPR-IN': 0.9637761116027832, 'AUPR-OUT': 0.9917054772377014, 'FPR95TPR': 0.06310000270605087}
Epoch 6
{'AUROC': 0.9830336570739746, 'AUPR-IN': 0.9557808637619019, 'AUPR-OUT': 0.9912922382354736, 'FPR95TPR': 0.05920000001788139}
Epoch 7
{'AUROC': 0.9907971620559692, 'AUPR-IN': 0.9819102883338928, 'AUPR-OUT': 0.9949544668197632, 'FPR95TPR': 0.04010000079870224}
Epoch 8
{'AUROC': 0.9874339699745178, 'AUPR-IN': 0.9670255184173584, 'AUPR-OUT': 0.9937484860420227, 'FPR95TPR': 0.041200000792741776}
Epoch 9
{'AUROC': 0.9871684312820435, 'AUPR-IN': 0.9701670408248901, 'AUPR-OUT': 0.9931120276451111, 'FPR95TPR': 0.051899999380111694}

Gallery generated by Sphinx-Gallery