Note
Go to the end to download the full example code.
Scone
We train a model with the Energy Margin loss on CIFAR10.
As a foundation, we use a model pre-trained on the \(32 \times 32\) resized version of ImageNet.
As outlier data, we use TinyImages300k, a cleaned version of the
TinyImages database, which contains random images scraped from the internet.
15 import torch
16 import torch.nn as nn
17 import torch.nn.functional as F
18 import torchvision.transforms as tvt
19 from torch import Tensor
20 from torch.optim import SGD
21 from torch.utils.data import DataLoader
22 from torchvision.datasets import CIFAR10
23
24 from typing import Callable
25
26 import numpy as np
27 from numpy import floating
28
29 from pytorch_ood.dataset.img import Textures, TinyImages300k
30 from pytorch_ood.detector import EnergyBased
31 from pytorch_ood.loss import EnergyMarginLoss
32 from pytorch_ood.model import WideResNet
33 from pytorch_ood.utils import OODMetrics, ToUnknown, to_np
34
# fix the seed so that weight initialisation and shuffling are reproducible
torch.manual_seed(123)

# number of training epochs
n_epochs = 100
# prefer the first GPU, but fall back to the CPU so that the example still
# runs on machines without CUDA
device = "cuda:0" if torch.cuda.is_available() else "cpu"
40
41
def evaluate_classification_loss_training(
    model: Callable[[Tensor], Tensor], train_loader_in
) -> floating:
    """
    Evaluate classification loss on ID training dataset.

    The model is switched to evaluation mode for the duration of the pass and
    is put back into training mode afterwards if it was in training mode on
    entry. No gradients are recorded during the pass.

    :param model: neural network to pass inputs to
    :param train_loader_in: iterable of batches; the first entry of each batch
        is used as input and the second as class target
    :return: average per-sample cross-entropy loss over all batches
        (``nan`` if the loader yields no batches)
    """
    # Remember the current mode so that a pure evaluation helper does not
    # silently leave the model in eval mode for a following training loop.
    was_training = getattr(model, "training", False)
    model.eval()

    # Run on the device the model lives on. If the model exposes no
    # parameters, fall back to "CUDA if available".
    first_param = (
        next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
    )
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    losses = []
    try:
        # Only loss values are needed, so do not build autograd graphs.
        with torch.no_grad():
            for in_set in train_loader_in:
                data = in_set[0].to(target_device)
                target = in_set[1].to(target_device)

                # forward
                y = model(data)

                # per-sample in-distribution classification loss
                loss_ce = F.cross_entropy(y, target, reduction="none")

                losses.append(loss_ce.detach().cpu().numpy())
    finally:
        if was_training:
            model.train()

    # An empty loader yields an empty array, whose mean is nan.
    per_sample = np.concatenate(losses) if losses else np.array(losses)
    avg_loss = np.mean(per_sample)
    print("average loss fr classification {}".format(avg_loss))

    return avg_loss
Setup preprocessing and data
# preprocessing shared by all datasets: resize to 32x32 and convert to tensor
trans = tvt.Compose(
    [
        tvt.Resize(size=(32, 32)),
        tvt.ToTensor(),
    ]
)

# in-distribution (ID) data: CIFAR10 training and test splits
dataset_in_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)

# TODO: Add Covariate-shifted Data

# OOD data: ToUnknown() marks the labels as OOD, so that the loss can decide
# for each sample in a mixed batch whether it is ID or OOD
dataset_out_train = TinyImages300k(
    root="data",
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)
dataset_out_test = Textures(
    root="data",
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# mixed ID + OOD loaders for training and testing
train_loader = DataLoader(
    dataset_in_train + dataset_out_train,
    batch_size=64,
    shuffle=True,
)
test_loader = DataLoader(
    dataset_in_test + dataset_out_test,
    batch_size=128,
)

# ID-only loader over the training split, used to evaluate the classification loss
train_loader_in = DataLoader(dataset_in_train, batch_size=128)
Create a DNN pretrained on ImageNet, excluding the CIFAR10 classes
# load the pretrained backbone with its original 1000-way head ...
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
# ... and swap the head for a fresh 10-way classifier, since CIFAR10 has
# fewer classes
model.fc = nn.Linear(model.fc.in_features, 10)
model.to(device)

# scalar-in, scalar-out linear layer that is handed to the loss and
# optimised jointly with the network
logistic_regression = nn.Linear(1, 1)
logistic_regression.to(device)

# one optimizer over the parameters of both modules
trainable_params = [*model.parameters(), *logistic_regression.parameters()]
opti = SGD(
    trainable_params,
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.0005,
    nesterov=True,
)

# Calculate ID Classification Loss Before Fine-Tuning
full_train_loss = evaluate_classification_loss_training(
    model=model,
    train_loader_in=train_loader_in,
)

criterion = EnergyMarginLoss(full_train_loss=full_train_loss)

# halve the learning rate after 50%, 75% and 90% of the epochs
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=opti,
    milestones=[int(n_epochs * fraction) for fraction in (0.5, 0.75, 0.9)],
    gamma=0.5,
)
Define a function to test the model
def test():
    """
    Score ``test_loader`` with an energy-based detector and print the metrics.

    The model is put into evaluation mode while scoring and switched back to
    training mode before returning.
    """
    detector = EnergyBased(model)
    ood_metrics = OODMetrics()

    model.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            scores = detector(inputs.to(device))
            ood_metrics.update(scores, labels)

    results = ood_metrics.compute()
    print(results)

    # hand the model back ready for the next training epoch
    model.train()
Start training
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    # Ensure training mode: evaluation code that runs before the first epoch
    # or between epochs may have left the model in eval mode.
    model.train()

    for x, y in train_loader:
        logits = model(x.to(device))
        # the loss receives the logistic regression layer alongside the
        # logits and the (ID / OOD) labels
        loss = criterion(logits, y.to(device), logistic_regression)
        opti.zero_grad()
        loss.backward()
        opti.step()

    # once per epoch, let the criterion update its hyperparameters using the
    # ID training data
    criterion.update_hyperparameters(
        model=model,
        train_loader_in=train_loader_in,
        logistic_regression=logistic_regression,
    )

    test()
    scheduler.step()