Scone

We train a model with the Energy Margin loss on CIFAR10.

We can use a model pre-trained on the \(32 \times 32\) resized version of ImageNet as a foundation. As outlier data, we use TinyImages300k, a cleaned version of the TinyImages database, which contains random images scraped from the internet.

15 import torch
16 import torch.nn as nn
17 import torch.nn.functional as F
18 import torchvision.transforms as tvt
19 from torch import Tensor
20 from torch.optim import SGD
21 from torch.utils.data import DataLoader
22 from torchvision.datasets import CIFAR10
23
24 from typing import Callable
25
26 import numpy as np
27 from numpy import floating
28
29 from pytorch_ood.dataset.img import Textures, TinyImages300k
30 from pytorch_ood.detector import EnergyBased
31 from pytorch_ood.loss import EnergyMarginLoss
32 from pytorch_ood.model import WideResNet
33 from pytorch_ood.utils import OODMetrics, ToUnknown, to_np
34
# fix the seed so that weight initialisation and shuffling are reproducible
torch.manual_seed(123)

# maximum number of training epochs
n_epochs = 100
# use the first GPU when one is present, otherwise fall back to the CPU
# so that the example still runs on machines without CUDA
device = "cuda:0" if torch.cuda.is_available() else "cpu"
40
41
def evaluate_classification_loss_training(
    model: Callable[[Tensor], Tensor], train_loader_in
) -> floating:
    """
    Evaluate classification loss on ID training dataset.

    The model is put into evaluation mode for the pass and switched back to
    training mode afterwards if it was in training mode on entry, so calling
    this function does not silently change the mode used by a later training
    loop. No gradients are tracked during the pass.

    :param model: neural network to pass inputs to
    :param train_loader_in: iterable of batches; element 0 of each batch holds
        the inputs and element 1 the class labels
    :return: mean of the per-sample cross-entropy losses over all batches
    """
    # remember the current mode so it can be restored after the evaluation
    was_training = getattr(model, "training", False)
    model.eval()

    # run on the device the model lives on; if the model exposes no
    # parameters, keep the previous behaviour (CUDA whenever it is available)
    first_param = (
        next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
    )
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    losses = []
    try:
        # pure evaluation: building the autograd graph would only cost memory
        with torch.no_grad():
            for in_set in train_loader_in:
                data = in_set[0].to(target_device)
                target = in_set[1].to(target_device)

                # forward
                y = model(data)

                # per-sample in-distribution classification loss
                loss_ce = F.cross_entropy(y, target, reduction="none")

                losses.extend(loss_ce.detach().cpu().numpy())
    finally:
        if was_training:
            model.train()

    avg_loss = np.mean(np.array(losses))
    print("average loss fr classification {}".format(avg_loss))

    return avg_loss

Setup preprocessing and data

 76 trans = tvt.Compose([tvt.Resize(size=(32, 32)), tvt.ToTensor()])
 77
 78 # setup ID training data
 79 dataset_in_train = CIFAR10(root="data", train=True, download=True, transform=trans)
 80
 81 # setup OOD training data, use ToUnknown() to mark labels as OOD
 82 # this way, outlier exposure can automatically decide if the training samples are ID or OOD
 83 dataset_out_train = TinyImages300k(
 84     root="data", download=True, transform=trans, target_transform=ToUnknown()
 85 )
 86
 87 # setup ID test data
 88 dataset_in_test = CIFAR10(root="data", train=False, transform=trans)
 89
 90 # TODO: Add Covariate-shifted Data
 91
 92 # setup OOD test data, use ToUnknown() to mark labels as OOD
 93 dataset_out_test = Textures(
 94     root="data", download=True, transform=trans, target_transform=ToUnknown()
 95 )
 96
 97 # create data loaders
 98 train_loader = DataLoader(dataset_in_train + dataset_out_train, batch_size=64, shuffle=True)
 99 test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=128)
100
101 train_loader_in = DataLoader(dataset_in_train, batch_size=128)

Create the DNN, pre-trained on ImageNet with the CIFAR10 classes excluded

# network created with 1000 output classes and the "imagenet32-nocifar" weights
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
# swap the 1000-way head for a fresh 10-way layer, since CIFAR10
# has fewer classes
model.fc = nn.Linear(model.fc.in_features, 10)
model.to(device)

# single-input, single-output linear layer that is handed to the loss
logistic_regression = nn.Linear(1, 1)
logistic_regression.to(device)

# one optimizer updates both the network and the logistic regression layer
opti = SGD(
    [*model.parameters(), *logistic_regression.parameters()],
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.0005,
    nesterov=True,
)

# ID classification loss before fine-tuning; passed to the loss below
full_train_loss = evaluate_classification_loss_training(model, train_loader_in)

criterion = EnergyMarginLoss(full_train_loss=full_train_loss)

# halve the learning rate after 50%, 75% and 90% of the epochs
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=opti,
    milestones=[int(n_epochs * fraction) for fraction in (0.5, 0.75, 0.9)],
    gamma=0.5,
)

Define a function to test the model

def test():
    """
    Score the test loader with the energy-based detector and print the metrics.

    The model is put into evaluation mode for the pass and switched back to
    training mode before returning.
    """
    detector = EnergyBased(model)
    metrics = OODMetrics()

    model.eval()

    with torch.no_grad():
        for inputs, labels in test_loader:
            scores = detector(inputs.to(device))
            metrics.update(scores, labels)

    print(metrics.compute())
    model.train()

Start training

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    # evaluate_classification_loss_training() runs before this loop and calls
    # model.eval(); set training mode explicitly so that the first epoch is
    # not trained with mode-dependent layers in evaluation mode
    model.train()
    for x, y in train_loader:
        logits = model(x.to(device))
        loss = criterion(logits, y.to(device), logistic_regression)
        opti.zero_grad()
        loss.backward()
        opti.step()
    # once per epoch, let the loss update its hyperparameters on the ID training data
    criterion.update_hyperparameters(
        model=model,
        train_loader_in=train_loader_in,
        logistic_regression=logistic_regression,
    )
    test()
    scheduler.step()

Gallery generated by Sphinx-Gallery