CIFAR10

Open Set Simulation on CIFAR10

 9 import torch.nn
10 from torch.utils.data import DataLoader
11 from torchvision.datasets import CIFAR10
12 from torch.nn import CrossEntropyLoss
13 from tqdm import tqdm
14 from torchmetrics import Accuracy
15
16 from pytorch_ood.dataset.ossim import DynamicOSS
17 from pytorch_ood.model import WideResNet
18 from pytorch_ood.detector import MaxSoftmax
19 from pytorch_ood.utils import fix_random_seed, TargetMapping, OODMetrics, is_known
20
# Run on the first GPU when one is available; otherwise fall back to the CPU
# so the example still works on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
num_epochs = 10

# Seed the random number generators for reproducibility.
fix_random_seed(12345)

Set up preprocessing

# Input transform and normalization std associated with the "cifar10-pt" key.
# NOTE(review): the backbone is loaded with "imagenet32-nocifar" weights;
# confirm whether that key's preprocessing should be used instead.
trans, norm_std = (
    WideResNet.transform_for("cifar10-pt"),
    WideResNet.norm_std_for("cifar10-pt"),
)

Set up the datasets

# Load the official train split first, then the test split, and pool them into
# a single dataset; the open set simulation re-splits the pooled data later.
dataset_1, dataset_2 = (
    CIFAR10(root="data", train=is_train, transform=trans, download=True)
    for is_train in (True, False)
)
dataset = dataset_1 + dataset_2

Create a DNN pre-trained on a downscaled version of ImageNet (excluding CIFAR images), and adjust it to output 7 logits.

print("Creating a Model")
# Backbone with a 1000-way head, loaded from the "imagenet32-nocifar" weights.
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
# Replace the classification head with a fresh linear layer that emits 7 logits,
# one per known class (10 CIFAR10 classes minus the 3 held out as unknown).
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=7)
model = model.to(device)

Create the open set simulation and data loaders. We use 3 classes as unknown unknowns (OOD) and a data split of 90% train and 10% test.

# Split the pooled data: 90% train / 10% test, no validation split,
# no known-unknown classes, and 3 unknown-unknown classes that only
# appear in the test split.
ossim = DynamicOSS(
    dataset=dataset,
    train_size=0.9,
    val_size=0.0,
    test_size=0.1,
    kuc=0,
    uuc_val=0,
    uuc_test=3,
    seed=1,
)
print(f"Known Classes: {ossim.kkc}")
print(f"Unknown Classes: {ossim.uuc}")

# create class remapping: known classes get training labels, unknown classes
# get labels that `is_known` rejects (used during evaluation)
class_mapping = TargetMapping(known=ossim.kkc, unknown=ossim.uuc)

# Reshuffle the training data every epoch so mini-batches are not identical
# across epochs; evaluation order does not matter, so the test loader stays ordered.
train_loader = DataLoader(ossim.train_dataset(), batch_size=32, shuffle=True, num_workers=12)
test_loader = DataLoader(ossim.test_dataset(), batch_size=32, num_workers=12)

criterion = CrossEntropyLoss()

opti = torch.optim.Adam(model.parameters(), lr=0.001)

Define a function for testing

@torch.no_grad()
def test():
    """Evaluate the model on the open-set test split and print the results.

    Prints the OOD detection metrics computed from maximum-softmax scores
    over all test samples, followed by the classification accuracy computed
    on known-class samples only.
    """
    metrics = OODMetrics()
    # Derive the number of classes from the classifier head instead of
    # repeating the literal, so the metric cannot drift from the model.
    acc = Accuracy(task="multiclass", num_classes=model.fc.out_features).to(device)
    model.eval()

    for x, y in tqdm(test_loader):
        # do not forget to remap class labels; build them directly on the device
        y = torch.tensor(
            [class_mapping(i.item()) for i in y], dtype=torch.long, device=device
        )
        x = x.to(device)

        z = model(x)

        # score every sample (known and unknown) with the maximum-softmax detector
        metrics.update(MaxSoftmax.score(z), y)

        # accuracy is only meaningful for samples whose remapped label is known
        known = is_known(y)
        if known.any():
            acc.update(z[known].argmax(dim=1), y[known])

    print(metrics.compute())
    print(acc.compute().item())

Start training

for epoch in range(num_epochs):
    model.train()
    progress = tqdm(train_loader)
    smoothed_loss = None  # exponential moving average of the per-batch loss

    for images, labels in progress:
        # do not forget to remap class labels
        targets = torch.tensor([class_mapping(label.item()) for label in labels]).to(device)
        images = images.to(device)

        # clearing gradients before the forward pass is equivalent to clearing
        # them right before backward(): the forward pass does not touch .grad
        opti.zero_grad()
        batch_loss = criterion(model(images), targets)
        batch_loss.backward()
        opti.step()

        current = batch_loss.item()
        if smoothed_loss is None:
            smoothed_loss = current
        else:
            smoothed_loss = smoothed_loss * 0.95 + current * 0.05
        progress.set_postfix_str(f"loss: {smoothed_loss:.2f}")

    # evaluate on the open-set test split after every epoch
    test()

Gallery generated by Sphinx-Gallery