Scone

We train a model with the Energy Margin loss on CIFAR10.

We use a model pre-trained on the \(32 \times 32\) downsampled version of ImageNet as a foundation. As outlier data, we use TinyImages300k, a cleaned version of the TinyImages database, which contains random images scraped from the internet.

14 import torch
15 import torch.nn as nn
16 import torch.nn.functional as F
17 import torchvision.transforms as tvt
18 from torch import Tensor
19 from torch.optim import SGD
20 from torch.utils.data import DataLoader
21 from torchvision.datasets import CIFAR10
22
23 from typing import Callable
24
25 import numpy as np
26 from numpy import floating
27
28 from pytorch_ood.dataset.img import Textures, TinyImages300k
29 from pytorch_ood.detector import EnergyBased
30 from pytorch_ood.loss import EnergyMarginLoss
31 from pytorch_ood.model import WideResNet
32 from pytorch_ood.utils import OODMetrics, ToUnknown, to_np
33
torch.manual_seed(123)

# number of training epochs
n_epochs = 100
# use the first GPU when one is available, otherwise fall back to the CPU
# (a hard-coded "cuda:0" crashes on CPU-only machines)
device = "cuda:0" if torch.cuda.is_available() else "cpu"
39
def evaluate_classification_loss_training(
    model: nn.Module, train_loader_in
) -> floating:
    """
    Evaluate the average classification loss on the ID training dataset.

    The model is run in eval mode without gradient tracking. Its previous
    train/eval mode is restored before returning, so calling this function
    does not change how subsequent training steps behave.

    :param model: neural network to pass inputs to
    :param train_loader_in: iterable of batches whose first two entries are
        ``(data, target)``, e.g. a DataLoader over the ID training set
    :return: mean per-sample cross-entropy loss as a numpy floating scalar
    :raises ValueError: if ``train_loader_in`` yields no samples
    """
    # Evaluate on the device the model's parameters live on; fall back to the
    # CUDA-if-available heuristic for parameter-less models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    was_training = model.training
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for in_set in train_loader_in:
                data = in_set[0].to(device)
                target = in_set[1].to(device)
                # forward
                y = model(data)
                # per-sample loss, so the final mean is over samples, not batches
                loss_ce = F.cross_entropy(y, target, reduction="none")
                losses.append(loss_ce.cpu())
    finally:
        # restore the caller's mode (e.g. keep a model that is being trained
        # in train mode)
        model.train(was_training)

    all_losses = torch.cat(losses) if losses else torch.empty(0)
    if all_losses.numel() == 0:
        raise ValueError("train_loader_in yielded no samples")

    avg_loss = np.mean(all_losses.numpy())
    print("average loss for classification {}".format(avg_loss))

    return avg_loss

Set up preprocessing and data

 73 trans = tvt.Compose([tvt.Resize(size=(32, 32)), tvt.ToTensor()])
 74
 75 # setup IN training data
 76 dataset_in_train = CIFAR10(root="data", train=True, download=True, transform=trans)
 77
 78 # setup OOD training data, use ToUnknown() to mark labels as OOD
 79 # this way, outlier exposure can automatically decide if the training samples are IN or OOD
 80 dataset_out_train = TinyImages300k(
 81     root="data", download=True, transform=trans, target_transform=ToUnknown()
 82 )
 83
 84 # setup IN test data
 85 dataset_in_test = CIFAR10(root="data", train=False, transform=trans)
 86
 87 # TODO: Add Covariate-shifted Data
 88
 89 # setup OOD test data, use ToUnknown() to mark labels as OOD
 90 dataset_out_test = Textures(
 91     root="data", download=True, transform=trans, target_transform=ToUnknown()
 92 )
 93
 94 # create data loaders
 95 train_loader = DataLoader(
 96     dataset_in_train + dataset_out_train, batch_size=64, shuffle=True
 97 )
 98 test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=128)
 99
100 train_loader_in = DataLoader(dataset_in_train, batch_size=128)

Create a DNN pre-trained on ImageNet, excluding the CIFAR10 classes

# WideResNet with the "imagenet32-nocifar" pre-trained weights
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
# CIFAR10 has only 10 classes, so the 1000-way classifier head is swapped
# for a new linear layer with the same input width
model.fc = nn.Linear(model.fc.in_features, 10)
model.to(device)

# 1 -> 1 linear layer that is handed to the loss together with the logits
logistic_regression = nn.Linear(1, 1)
logistic_regression.to(device)

# one optimizer updates both the network and the logistic regression layer
opti = SGD(
    [*model.parameters(), *logistic_regression.parameters()],
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.0005,
    nesterov=True,
)

# ID classification loss of the pre-trained network before fine-tuning;
# it parameterizes the Energy Margin loss
full_train_loss = evaluate_classification_loss_training(
    model=model, train_loader_in=train_loader_in
)
criterion = EnergyMarginLoss(full_train_loss=full_train_loss)

# halve the learning rate after 50 %, 75 % and 90 % of the epochs
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=opti,
    milestones=[int(n_epochs * frac) for frac in (0.5, 0.75, 0.9)],
    gamma=0.5,
)

Define a function to test the model

def test():
    """
    Report OOD detection metrics of the energy-based detector on ``test_loader``.

    Switches the global ``model`` to eval mode and scores every test batch
    without gradient tracking. It then prints the result of
    ``OODMetrics.compute()`` and puts the model back into train mode.
    """
    detector = EnergyBased(model)

    metrics = OODMetrics()
    model.eval()

    with torch.no_grad():
        for inputs, labels in test_loader:
            scores = detector(inputs.to(device))
            metrics.update(scores, labels)

    print(metrics.compute())
    model.train()

Start training

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    # The model may have been left in eval mode (e.g. by the classification
    # loss evaluation done before training), so explicitly switch to train
    # mode. Otherwise mode-dependent layers such as BatchNorm or dropout,
    # if present, would not behave as in training.
    model.train()
    for x, y in train_loader:
        logits = model(x.to(device))
        loss = criterion(logits, y.to(device), logistic_regression)
        opti.zero_grad()
        loss.backward()
        opti.step()
    # let the criterion update its hyperparameters using the ID training data
    criterion.update_hyperparameters(
        model=model,
        train_loader_in=train_loader_in,
        logistic_regression=logistic_regression,
    )
    test()
    scheduler.step()

Gallery generated by Sphinx-Gallery