Note
Go to the end to download the full example code.
CIFAR10
Open Set Simulation on CIFAR10
9 import torch.nn
10 from torch.utils.data import DataLoader
11 from torchvision.datasets import CIFAR10
12 from torch.nn import CrossEntropyLoss
13 from tqdm import tqdm
14 from torchmetrics import Accuracy
15
16 from pytorch_ood.dataset.ossim import DynamicOSS
17 from pytorch_ood.model import WideResNet
18 from pytorch_ood.detector import MaxSoftmax
19 from pytorch_ood.utils import fix_random_seed, TargetMapping, OODMetrics, is_known
20
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so the
# example can still run (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
num_epochs = 10  # number of passes over the training split

# Fix the random seed via pytorch_ood's helper so runs are repeatable.
fix_random_seed(12345)
Set up preprocessing
# Both lookups use WideResNet's "cifar10-pt" configuration (presumably so the
# inputs match what the pre-trained weights expect). norm_std is not used further
# in the code shown here.
norm_std = WideResNet.norm_std_for("cifar10-pt")
trans = WideResNet.transform_for("cifar10-pt")
Set up datasets
# Load both official CIFAR10 splits with the same preprocessing and merge them.
# The open-set simulator below performs its own train/test split on the result.
dataset_1 = CIFAR10(root="data", download=True, train=True, transform=trans)
dataset_2 = CIFAR10(root="data", download=True, train=False, transform=trans)
dataset = dataset_1 + dataset_2
Create a DNN pre-trained on a downscaled version of ImageNet (excluding CIFAR images) and adjust it to output 7 logits.
print("Creating a Model")
# Build a WideResNet with the "imagenet32-nocifar" pre-trained weights (1000-way
# head), then replace its final linear layer with a 7-way one: one logit per
# known class in the open-set split below.
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=7)
model = model.to(device)
Create the open set simulation and data loaders. We use 3 classes as unknown unknowns (OOD) and a data split of 90% train and 10% test.
# Open-set simulation over the merged CIFAR10 data: 90% train / 10% test, no
# validation split, no known-unknown classes (kuc=0), and 3 unknown-unknown
# classes that are held out for the test split (uuc_test=3).
ossim = DynamicOSS(
    dataset=dataset,
    train_size=0.9,
    val_size=0.0,
    test_size=0.1,
    kuc=0,
    uuc_val=0,
    uuc_test=3,
    seed=1,
)
print(f"Known Classes: {ossim.kkc}")
print(f"Unknown Classes: {ossim.uuc}")

# create class remapping: raw CIFAR10 labels -> labels used for training and
# evaluation (known vs. unknown is later checked with is_known()).
class_mapping = TargetMapping(known=ossim.kkc, unknown=ossim.uuc)

# Shuffle the training data so each epoch sees mini-batches in a fresh random
# order rather than the fixed order of the split. Torch's RNG is seeded above,
# so this stays reproducible. The test loader keeps a fixed order.
train_loader = DataLoader(
    ossim.train_dataset(), batch_size=32, shuffle=True, num_workers=12
)
test_loader = DataLoader(ossim.test_dataset(), batch_size=32, num_workers=12)

criterion = CrossEntropyLoss()

opti = torch.optim.Adam(model.parameters(), lr=0.001)
Define a function for testing
@torch.no_grad()
def test():
    """Evaluate ``model`` on ``test_loader`` and print the results.

    Prints two things: the OOD metrics computed from MaxSoftmax scores over all
    test samples, and the multiclass accuracy over the samples whose remapped
    label is known. Uses the module-level ``model``, ``test_loader``,
    ``class_mapping`` and ``device``.
    """
    ood_metrics = OODMetrics()
    accuracy = Accuracy(task="multiclass", num_classes=7).to(device)
    model.eval()

    for inputs, raw_labels in tqdm(test_loader):
        # Remap raw CIFAR10 labels into the known/unknown label space first.
        labels = torch.tensor(
            [class_mapping(label.item()) for label in raw_labels]
        ).to(device)
        inputs = inputs.to(device)

        logits = model(inputs)
        ood_metrics.update(MaxSoftmax.score(logits), labels)

        known_mask = is_known(labels)
        if not known_mask.any():
            continue
        # Classification accuracy is only defined for classes the head can predict.
        accuracy.update(logits[known_mask].argmax(dim=1), labels[known_mask])

    print(ood_metrics.compute())
    print(accuracy.compute().item())
Start training
# Train for num_epochs epochs, evaluating on the open-set test split after each.
for epoch in range(num_epochs):
    progress = tqdm(train_loader)
    model.train()
    running_loss = None
    for inputs, raw_labels in progress:
        # Remap raw CIFAR10 labels to the classifier's label indices.
        labels = torch.tensor(
            [class_mapping(label.item()) for label in raw_labels]
        ).to(device)
        inputs = inputs.to(device)

        logits = model(inputs)

        loss = criterion(logits, labels)
        opti.zero_grad()
        loss.backward()
        opti.step()

        # Exponential moving average of the loss, for a smoother progress display.
        if running_loss is None:
            running_loss = loss.item()
        else:
            running_loss = running_loss * 0.95 + loss.item() * 0.05
        progress.set_postfix_str(f"loss: {running_loss:.2f}")

    test()