Note
Go to the end to download the full example code.
CIFAR10
Open-Set Simulation on CIFAR-10
8 import torch.nn
9 from torch.utils.data import DataLoader
10 from torchvision.datasets import CIFAR10
11 from torch.nn import CrossEntropyLoss
12 from tqdm import tqdm
13 from torchmetrics import Accuracy
14
15 from pytorch_ood.dataset.ossim import DynamicOSS
16 from pytorch_ood.model import WideResNet
17 from pytorch_ood.detector import MaxSoftmax
18 from pytorch_ood.utils import fix_random_seed, TargetMapping, OODMetrics, is_known
19
# Run on the first GPU when CUDA is available; otherwise fall back to the CPU
# so the example still runs (slowly) on machines without a GPU instead of
# crashing on the first `.to(device)` call.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
num_epochs = 10

# Fix the random seed via pytorch_ood's helper so runs are reproducible.
fix_random_seed(12345)
Set up preprocessing
# Input transform (and the normalization std) matching the "cifar10-pt"
# WideResNet configuration. Note: `norm_std` is not referenced further in
# this example.
trans, norm_std = (
    WideResNet.transform_for("cifar10-pt"),
    WideResNet.norm_std_for("cifar10-pt"),
)
Set up datasets
# Load the CIFAR-10 train split (train=True) and test split (train=False),
# then pool them into one dataset for the open-set simulation to re-split.
dataset_1, dataset_2 = (
    CIFAR10(root="data", train=is_train_split, transform=trans, download=True)
    for is_train_split in (True, False)
)
dataset = dataset_1 + dataset_2
Create a DNN pre-trained on a downscaled version of ImageNet (excluding CIFAR images) and adjust it to output 7 logits, one for each known class (10 CIFAR-10 classes minus the 3 held-out unknown classes).
print("Creating a Model")
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
# Swap the 1000-way pretrained classification head for a freshly initialized
# 7-way head with the same input width.
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=7)
model = model.to(device)
Create the open-set simulation and data loaders. We use 3 classes as unknown unknowns (OOD) and split the data into 90% training and 10% test data; no validation split and no known-unknown classes are used.
# Open-set simulation over the pooled CIFAR-10 data: a 90%/10% train/test
# split with no validation split (val_size=0.0), no known-unknown classes
# (kuc=0), and 3 unknown-unknown classes for the test split (uuc_test=3).
ossim = DynamicOSS(
    dataset=dataset,
    train_size=0.9,
    val_size=0.0,
    test_size=0.1,
    kuc=0,
    uuc_val=0,
    uuc_test=3,
    seed=1,
)
print(f"Known Classes: {ossim.kkc}")
print(f"Unknown Classes: {ossim.uuc}")

# Create the class remapping from original CIFAR-10 labels to the labels used
# during training/evaluation (known vs. unknown classes).
class_mapping = TargetMapping(known=ossim.kkc, unknown=ossim.uuc)

# Shuffle the training data so every epoch sees the samples in a new order
# rather than one fixed sequence of batches. Evaluation order does not affect
# the metrics, so the test loader stays unshuffled.
train_loader = DataLoader(
    ossim.train_dataset(), batch_size=32, shuffle=True, num_workers=12
)
test_loader = DataLoader(ossim.test_dataset(), batch_size=32, num_workers=12)

criterion = CrossEntropyLoss()

opti = torch.optim.Adam(model.parameters(), lr=0.001)
Define a function for testing
@torch.no_grad()
def test():
    """Evaluate the model on the open-set test split and print the results.

    For every test batch, the original labels are remapped with
    ``class_mapping``. The MaxSoftmax outlier scores are then fed into
    ``OODMetrics``. Closed-set accuracy is accumulated only over the samples
    that ``is_known`` flags as known.

    Prints the computed OOD metrics, followed by the accuracy as a float.
    Returns None.
    """
    ood_metrics = OODMetrics()
    closed_set_acc = Accuracy(task="multiclass", num_classes=7).to(device)
    model.eval()

    for images, labels in tqdm(test_loader):
        # remap the original class labels before any metric sees them
        remapped = [class_mapping(label.item()) for label in labels]
        labels = torch.tensor(remapped).to(device)
        images = images.to(device)

        logits = model(images)

        ood_metrics.update(MaxSoftmax.score(logits), labels)

        known_mask = is_known(labels)
        if not known_mask.any():
            # batch contains only unknown samples: nothing to score for accuracy
            continue
        closed_set_acc.update(logits[known_mask].argmax(dim=1), labels[known_mask])

    print(ood_metrics.compute())
    print(closed_set_acc.compute().item())
Start training
for epoch in range(num_epochs):
    model.train()
    progress = tqdm(train_loader)
    running_loss = None

    for images, labels in progress:
        # remap the original class labels with the open-set class mapping
        labels = torch.tensor([class_mapping(label.item()) for label in labels]).to(device)
        images = images.to(device)

        logits = model(images)
        loss = criterion(logits, labels)

        opti.zero_grad()
        loss.backward()
        opti.step()

        # exponential moving average of the batch loss, shown on the progress bar
        step_loss = loss.item()
        if running_loss is None:
            running_loss = step_loss
        else:
            running_loss = 0.95 * running_loss + 0.05 * step_loss
        progress.set_postfix_str(f"loss: {running_loss:.2f}")

    # evaluate on the open-set test split after every epoch
    test()