Virtual Outlier Synthesizer Loss

We train a model with the Virtual Outlier Synthesizer Loss on CIFAR10.

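We then evaluate the WeightedEBO OOD detector, together with the EnergyBased detector for comparison, using the Textures dataset as OOD data.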

As a foundation, we can use a model pre-trained on the \(32 \times 32\) resized version of ImageNet.

11 import numpy as np
12 import torch
13 from torch.optim.lr_scheduler import CosineAnnealingLR
14 from torch.utils.data import DataLoader
15 from torchvision.datasets import CIFAR10
16
17 from pytorch_ood.dataset.img import Textures
18 from pytorch_ood.detector import EnergyBased, WeightedEBO
19 from pytorch_ood.loss import VirtualOutlierSynthesizingRegLoss
20 from pytorch_ood.model import WideResNet
21 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
22
# Run on the first CUDA device when one is available; otherwise fall back to
# the CPU so the example still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Training hyperparameters.
batch_size = 128
num_epochs = 10
lr = 0.1  # initial learning rate handed to the SGD optimizer
num_classes = 10  # CIFAR10 has ten classes

# NOTE(review): fix_random_seed presumably seeds the global RNGs for
# reproducibility -- confirm against the pytorch_ood documentation.
fix_random_seed(12345)

# Dedicated, separately seeded generator that is passed to the training
# DataLoader (``generator=g``).
g = torch.Generator()
g.manual_seed(0)

Set up the datasets; we train on CIFAR10.

# Transform obtained from WideResNet for the "cifar10-pt" setting; the same
# transform is applied to every dataset below.
trans = WideResNet.transform_for("cifar10-pt")

# In-distribution data: the CIFAR10 train split (downloaded if missing) ...
dataset = CIFAR10("data", train=True, download=True, transform=trans)

# ... and the CIFAR10 test split as IN test data.
dataset_in_test = CIFAR10("data", train=False, transform=trans)

# OOD test data; ToUnknown() as target transform marks the labels as OOD.
dataset_out_test = Textures(
    root="data",
    transform=trans,
    target_transform=ToUnknown(),
    download=True,
)

# Shuffled training loader. Each worker process is initialised through
# fix_random_seed, and the loader draws its randomness from generator `g`.
loader_options = {
    "batch_size": batch_size,
    "shuffle": True,
    "num_workers": 10,
    "worker_init_fn": fix_random_seed,
    "generator": g,
}
loader = DataLoader(dataset, **loader_options)

Set up the model

# WideResNet loaded with the "cifar10-pt" pre-trained weights and moved to the
# target device. The output size reuses the `num_classes` constant instead of a
# second literal, so it stays in sync with the loss and the energy weights.
model = WideResNet(num_classes=num_classes, pretrained="cifar10-pt").to(device)

Create the additional neural network layers required by the loss, then the loss itself

# Two small trainable layers that are handed to the loss.
# NOTE(review): in the VOS formulation `phi` presumably plays the role of the
# logistic regression on the energy score and `weights_energy` holds the
# per-class energy weights -- confirm against the pytorch_ood documentation.
phi = torch.nn.Linear(in_features=1, out_features=2).to(device)
weights_energy = torch.nn.Linear(in_features=num_classes, out_features=1).to(device)
torch.nn.init.uniform_(weights_energy.weight)  # in-place uniform initialisation

# Keyword configuration of the loss, collected in one place.
vos_options = {
    "device": device,
    "num_classes": num_classes,
    # presumably the width of the features fed into model.fc -- verify
    "num_input_last_layer": 128,
    "fc": model.fc,
    "sample_number": 10000,
    "select": 1,
    "sample_from": 1000,
    "alpha": 0.1,
}
criterion = VirtualOutlierSynthesizingRegLoss(phi, weights_energy, **vos_options)

Train model for some epochs

 83 optimizer = torch.optim.SGD(
 84     list(model.parameters()) + list(phi.parameters()) + list(weights_energy.parameters()),
 85     lr=lr,
 86     momentum=0.9,
 87     weight_decay=5e-4,
 88 )
 89
 90
 91 # setup scheduler for optimizer (recommended)
 92 scheduler = CosineAnnealingLR(
 93     optimizer,
 94     T_max=num_epochs * len(loader),
 95 )
 96 loss_ema = 0
 97
 98 for epoch in range(num_epochs):
 99     for n, (x, y) in enumerate(loader):
100         optimizer.zero_grad()
101         y, x = y.to(device), x.to(device)
102
103         features = model.features(x)
104         y_hat = model.fc(features)
105         loss = criterion(y_hat, features, y)
106
107         loss.backward()
108         optimizer.step()
109         scheduler.step()
110
111         loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
112
113         if n % 10 == 0:
114             print(f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f}")

Evaluate

print("Evaluating")
model.eval()

# IN and OOD test sets are concatenated so that both detectors are scored on
# exactly the same mixture of samples.
test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=64)
detector_weightedEBO = WeightedEBO(model, weights_energy)
detector_energyBased = EnergyBased(model)
metrics_weightedEBO = OODMetrics()
metrics_energyBased = OODMetrics()

with torch.no_grad():
    for n, (x, y) in enumerate(test_loader):
        y, x = y.to(device), x.to(device)

        # One forward pass per batch; both detectors score the same outputs.
        y_hat = model(x)
        scores_weighted = detector_weightedEBO.predict_features(y_hat)
        scores_energy = detector_energyBased.predict_features(y_hat)

        metrics_weightedEBO.update(scores_weighted, y)
        metrics_energyBased.update(scores_energy, y)

        if n % 10 == 0:
            # Report progress by evaluation batch. The training loop's `epoch`
            # variable is intentionally not referenced here: it would print a
            # stale epoch number and is undefined when no epoch was run.
            print(f"Eval [{n:05d}/{len(test_loader):05d}]")

print(f"WeightedEBO: {metrics_weightedEBO.compute()}")
print(f"EnergyBased: {metrics_energyBased.compute()}")

Output: WeightedEBO: {'AUROC': 0.9192541837692261, 'AUPR-IN': 0.8389347195625305, 'AUPR-OUT': 0.954131007194519, 'FPR95TPR': 0.2897000014781952} EnergyBased: {'AUROC': 0.9227883815765381, 'AUPR-IN': 0.8493221998214722, 'AUPR-OUT': 0.956129789352417, 'FPR95TPR': 0.2799000144004822}

Gallery generated by Sphinx-Gallery