Virtual Outlier Synthesizer Loss

We train a model with the Virtual Outlier Synthesizer Loss on CIFAR10.

We then evaluate OOD detection with the WeightedEBO detector and compare it against the standard energy-based detector (EnergyBased).

We can use a model pre-trained on the \(32 \times 32\) resized version of ImageNet as a foundation.

import numpy as np
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from pytorch_ood.dataset.img import Textures
from pytorch_ood.detector import EnergyBased, WeightedEBO
from pytorch_ood.loss import VirtualOutlierSynthesizingRegLoss
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed

# Fall back to the CPU when no CUDA device is available so the example still
# runs (slowly) instead of failing on the first ``.to(device)`` call.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 128
num_epochs = 10
lr = 0.1
num_classes = 10  # CIFAR10 has 10 classes

# fix_random_seed presumably seeds the global RNGs (confirm in pytorch_ood.utils);
# the separate generator ``g`` is handed to the shuffling training DataLoader so
# its sample order is reproducible.
fix_random_seed(12345)
g = torch.Generator()
g.manual_seed(0)

Set up the datasets; we train on CIFAR10.

trans = WideResNet.transform_for("cifar10-pt")

# In-distribution data: the CIFAR10 train split is used for training and the
# test split for evaluation. The test split is opened without ``download=True``,
# so it relies on the files fetched by the training-split call directly above.
dataset = CIFAR10(root="data", train=True, download=True, transform=trans)
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)

# Out-of-distribution test data: ToUnknown() marks every Textures label as OOD.
dataset_out_test = Textures(
    root="data",
    transform=trans,
    target_transform=ToUnknown(),
    download=True,
)

# Shuffled training loader; ``g`` fixes the shuffling order and every worker
# process is initialised through fix_random_seed.
loader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
    generator=g,
    num_workers=10,
    worker_init_fn=fix_random_seed,
)

Set up the model

# WideResNet backbone loaded with the "cifar10-pt" weights. The classifier width
# comes from the shared ``num_classes`` setting so it always matches the dataset
# and the VOS layers below.
model = WideResNet(num_classes=num_classes, pretrained="cifar10-pt").to(device)

Create the additional layers required by the VOS loss

# Auxiliary layers that the VOS regularisation loss trains jointly with the model.
# ``phi`` maps a single scalar score to two logits. This is presumably the
# in-distribution vs. virtual-outlier decision of VOS; confirm against the loss
# implementation.
phi = torch.nn.Linear(1, 2).to(device)
# One learnable weight per class. The same module is later given to the
# WeightedEBO detector at evaluation time.
weights_energy = torch.nn.Linear(num_classes, 1).to(device)
torch.nn.init.uniform_(weights_energy.weight)  # default bounds: U(0, 1)

criterion = VirtualOutlierSynthesizingRegLoss(
    phi,
    weights_energy,
    device=device,
    num_classes=num_classes,
    # Width of the features that enter ``model.fc``. It is read from the layer
    # instead of being hard-coded (128) so it cannot drift from the backbone.
    num_input_last_layer=model.fc.in_features,
    fc=model.fc,
    # VOS sampling hyperparameters; see the pytorch_ood loss documentation
    # for their exact meaning.
    sample_number=10000,
    select=1,
    sample_from=1000,
    alpha=0.1,
)

Train the model for a few epochs

 84 optimizer = torch.optim.SGD(
 85     list(model.parameters()) + list(phi.parameters()) + list(weights_energy.parameters()),
 86     lr=lr,
 87     momentum=0.9,
 88     weight_decay=5e-4,
 89 )
 90
 91
 92 # setup scheduler for optimizer (recommended)
 93 scheduler = CosineAnnealingLR(
 94     optimizer,
 95     T_max=num_epochs * len(loader),
 96 )
 97 loss_ema = 0
 98
 99 for epoch in range(num_epochs):
100     for n, (x, y) in enumerate(loader):
101         optimizer.zero_grad()
102         y, x = y.to(device), x.to(device)
103
104         features = model.features(x)
105         y_hat = model.fc(features)
106         loss = criterion(y_hat, features, y)
107
108         loss.backward()
109         optimizer.step()
110         scheduler.step()
111
112         loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
113
114         if n % 10 == 0:
115             print(f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f}")

Evaluate

print("Evaluating")
model.eval()
# ``+`` on two datasets yields a single concatenated dataset: ID samples first,
# then OOD samples.
test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=64)
# WeightedEBO reuses the per-class energy weights learned during training.
# The plain energy-based detector serves as the baseline.
detector_weightedEBO = WeightedEBO(model, weights_energy)
detector_energyBased = EnergyBased(model)
metrics_weightedEBO = OODMetrics()
metrics_energyBased = OODMetrics()

with torch.no_grad():
    for n, (x, y) in enumerate(test_loader):
        y, x = y.to(device), x.to(device)
        # Compute the logits once and score them with both detectors.
        y_hat = model(x)
        o = detector_weightedEBO.predict_features(y_hat)
        o1 = detector_energyBased.predict_features(y_hat)

        metrics_weightedEBO.update(o, y)
        metrics_energyBased.update(o1, y)
        if n % 10 == 0:
            # Progress over test batches. There are no epochs here, and the
            # training loop variable ``epoch`` is undefined when num_epochs == 0.
            print(f"Batch [{n:05d}/{len(test_loader):05d}]")

print(f"WeightedEBO: {metrics_weightedEBO.compute()}")
print(f"EnergyBased: {metrics_energyBased.compute()}")

Output:

WeightedEBO: {'AUROC': 0.9192541837692261, 'AUPR-IN': 0.8389347195625305, 'AUPR-OUT': 0.954131007194519, 'FPR95TPR': 0.2897000014781952}
EnergyBased: {'AUROC': 0.9227883815765381, 'AUPR-IN': 0.8493221998214722, 'AUPR-OUT': 0.956129789352417, 'FPR95TPR': 0.2799000144004822}

Gallery generated by Sphinx-Gallery