Note
Go to the end to download the full example code.
Scone
We train a model with the Energy Margin loss on CIFAR10.
We use a model pre-trained on the \(32 \times 32\) resized version of ImageNet as a foundation.
As outlier data, we use TinyImages300k, a cleaned version of the
TinyImages database, which contains random images scraped from the internet.
14 import torch
15 import torch.nn as nn
16 import torch.nn.functional as F
17 import torchvision.transforms as tvt
18 from torch import Tensor
19 from torch.optim import SGD
20 from torch.utils.data import DataLoader
21 from torchvision.datasets import CIFAR10
22
23 from typing import Callable
24
25 import numpy as np
26 from numpy import floating
27
28 from pytorch_ood.dataset.img import Textures, TinyImages300k
29 from pytorch_ood.detector import EnergyBased
30 from pytorch_ood.loss import EnergyMarginLoss
31 from pytorch_ood.model import WideResNet
32 from pytorch_ood.utils import OODMetrics, ToUnknown, to_np
33
# fix the RNG seed so runs are reproducible
torch.manual_seed(123)

# maximum number of training epochs
n_epochs = 100
# use the first GPU when present, otherwise fall back to the CPU so the
# example still runs on machines without CUDA
device = "cuda:0" if torch.cuda.is_available() else "cpu"
39
def evaluate_classification_loss_training(
    model: Callable[[Tensor], Tensor], train_loader_in, device=None
) -> floating:
    """
    Evaluate the average classification loss on the ID training dataset.

    The model is put into evaluation mode for the pass and afterwards restored
    to the train/eval mode it had before the call, so a training loop that
    follows is not silently run in eval mode.

    :param model: neural network to pass inputs to; must provide ``eval()``.
        If it has a ``training`` flag, that mode is restored on exit
    :param train_loader_in: iterable of batches whose first two items are
        ``(data, target)``
    :param device: device to move batches to. If ``None``, the device of the
        model's first parameter is used, falling back to CUDA when available
        and the CPU otherwise
    :return: mean per-sample cross-entropy loss
    :raises ValueError: if ``train_loader_in`` yields no samples
    """
    if device is None:
        try:
            device = next(model.parameters()).device
        except (AttributeError, StopIteration):
            # model exposes no parameters; pick the best available device
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # remember the caller's mode so it can be restored even if evaluation fails
    was_training = getattr(model, "training", None)
    model.eval()
    losses = []
    try:
        # gradients are never needed here; no_grad avoids building the graph
        with torch.no_grad():
            for batch in train_loader_in:
                data = batch[0].to(device)
                target = batch[1].to(device)

                logits = model(data)
                # per-sample losses so the final mean is weighted by samples,
                # not by batches (the last batch may be smaller)
                loss_ce = F.cross_entropy(logits, target, reduction="none")
                losses.append(loss_ce.cpu().numpy())
    finally:
        if was_training is not None:
            model.train(was_training)

    all_losses = np.concatenate(losses) if losses else np.empty(0)
    if all_losses.size == 0:
        # np.mean of an empty array is NaN, which would silently poison
        # anything derived from this baseline
        raise ValueError("train_loader_in yielded no samples")

    avg_loss = np.mean(all_losses)
    print(f"average loss for classification {avg_loss}")

    return avg_loss
Set up preprocessing and data
# Preprocessing: resize every image to 32x32 and convert it to a tensor
# (no normalization is applied).
trans = tvt.Compose([tvt.Resize(size=(32, 32)), tvt.ToTensor()])

# --- In-distribution data: CIFAR10 train and test splits ---
dataset_in_train = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)

# TODO: Add Covariate-shifted Data

# --- Outlier data: every target is relabelled via ToUnknown() ---
# training outliers
dataset_out_train = TinyImages300k(
    root="data", download=True, transform=trans, target_transform=ToUnknown()
)
# test-time OOD set
dataset_out_test = Textures(
    root="data", download=True, transform=trans, target_transform=ToUnknown()
)

# --- Loaders ---
# mixed ID + outlier stream, shuffled so batches interleave both sources
train_loader = DataLoader(
    dataset=dataset_in_train + dataset_out_train, batch_size=64, shuffle=True
)
# mixed ID + OOD stream for evaluation
test_loader = DataLoader(dataset=dataset_in_test + dataset_out_test, batch_size=128)
# ID-only training stream, used to measure the classification loss
train_loader_in = DataLoader(dataset=dataset_in_train, batch_size=128)
Create a DNN pretrained on ImageNet, excluding the CIFAR10 classes
# WideResNet pretrained on 32x32 ImageNet without the CIFAR10 classes; the
# 1000-way ImageNet head is then swapped for a fresh 10-way CIFAR10 head.
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
model.fc = torch.nn.Linear(model.fc.in_features, 10)
model.to(device)

# 1 -> 1 linear layer handed to the criterion; it is optimized jointly with
# the network below.
logistic_regression = nn.Linear(1, 1)
logistic_regression.to(device)

opti = SGD(
    [*model.parameters(), *logistic_regression.parameters()],
    lr=0.0001,
    momentum=0.9,
    weight_decay=0.0005,
    nesterov=True,
)

# ID classification loss of the model before fine-tuning; the Energy Margin
# criterion is constructed from this value.
full_train_loss = evaluate_classification_loss_training(
    model=model, train_loader_in=train_loader_in
)
criterion = EnergyMarginLoss(full_train_loss=full_train_loss)

# halve the learning rate at 50%, 75% and 90% of the epoch budget
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer=opti,
    milestones=[int(n_epochs * frac) for frac in (0.5, 0.75, 0.9)],
    gamma=0.5,
)
Define a function to test the model
def test():
    """
    Evaluate OOD detection on ``test_loader`` with an energy-based detector.

    Uses the module-level ``model``, ``test_loader`` and ``device``. The model
    is switched to eval mode for the pass. It is always put back into training
    mode afterwards, even if evaluation raises, so the training loop that calls
    this keeps training in train mode.

    :return: the result of ``OODMetrics.compute()`` (also printed)
    """
    energy = EnergyBased(model)
    metrics_energy = OODMetrics()
    model.eval()

    try:
        with torch.no_grad():
            for x, y in test_loader:
                metrics_energy.update(energy(x.to(device)), y)

        results = metrics_energy.compute()
        print(results)
    finally:
        # restore training mode on both the success and the error path
        model.train()

    return results
Start training
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    # Enter train mode explicitly every epoch. Code that runs around this
    # loop (the baseline loss evaluation, the hyperparameter update, test())
    # may call model.eval(), and the optimization steps below must not run
    # with batch-norm/dropout in eval mode.
    model.train()
    for x, y in train_loader:
        logits = model(x.to(device))
        loss = criterion(logits, y.to(device), logistic_regression)
        opti.zero_grad()
        loss.backward()
        opti.step()

    # once per epoch: refresh the criterion's hyperparameters, report OOD
    # metrics, then advance the epoch-based LR schedule
    criterion.update_hyperparameters(
        model=model,
        train_loader_in=train_loader_in,
        logistic_regression=logistic_regression,
    )
    test()
    scheduler.step()