Note
Go to the end to download the full example code.
Class Anchor Clustering
Class Anchor Clustering (CAC) can be seen as a multi-class generalization of Deep
One-Class Learning, where there are
several centers \(\{\mu_1, \mu_2, \dots, \mu_C\}\) in the output space of the model, one for each of the \(C\) classes.
During training, the representation \(f_{\theta}(x)\) from class \(y\) is drawn
towards the corresponding center \(\mu_y\).
Here, we train the model for 10 epochs on the CIFAR10 dataset, starting from a backbone pre-trained on the \(32 \times 32\) resized version of ImageNet.
15 import torch
16 from torch.optim import Adam
17 from torch.optim.lr_scheduler import CosineAnnealingLR
18 from torch.utils.data import DataLoader
19 from torchmetrics import Accuracy
20 from torchvision.datasets import CIFAR10
21 from tqdm import tqdm
22
23 from pytorch_ood.dataset.img import Textures
24 from pytorch_ood.loss import CACLoss
25 from pytorch_ood.model import WideResNet
26 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
27
# fix the random seed before any data is loaded
fix_random_seed(123)

n_epochs = 10
device = "cuda:0"

# input transform associated with the "imagenet32-nocifar" weights
trans = WideResNet.transform_for("imagenet32-nocifar")

# IN (in-distribution) training data, downloaded on demand
dataset_in_train = CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=trans,
)

# IN test data
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)

# OOD test data; ToUnknown() marks its labels as OOD
dataset_out_test = Textures(
    root="data",
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# data loaders: shuffled batches for training, IN and OOD test sets
# concatenated for evaluation
train_loader = DataLoader(
    dataset_in_train,
    batch_size=64,
    shuffle=True,
    num_workers=16,
)
test_loader = DataLoader(
    dataset_in_test + dataset_out_test,
    batch_size=64,
    num_workers=16,
)
Create the DNN, pre-trained on ImageNet excluding the CIFAR10 classes. We have to replace the final layer to match the number of classes.
# backbone with the "imagenet32-nocifar" weights and its original 1000-way head
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
# swap the final layer for one with 10 outputs (one per CIFAR10 class)
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=10)
model = model.to(device)

opti = Adam(params=model.parameters())

criterion = CACLoss(n_classes=10, magnitude=5, alpha=2)
criterion = criterion.to(device)

# one cosine cycle spanning every optimisation step of the whole run
scheduler = CosineAnnealingLR(optimizer=opti, T_max=len(train_loader) * n_epochs)
Define a function that evaluates the model
def test():
    """Evaluate the model on the combined IN + OOD test loader.

    Prints the OOD metrics and the classification accuracy measured on the
    samples selected by ``is_known``, then puts the model back into
    training mode.
    """
    ood_metrics = OODMetrics()
    accuracy = Accuracy(num_classes=10, task="multiclass")

    model.eval()

    with torch.no_grad():
        with tqdm(test_loader, desc="Testing") as progress:
            for inputs, targets in progress:
                # embeddings, then the distance of each embedding to each center
                embeddings = model(inputs.to(device))
                dists = criterion.distance(embeddings).cpu()

                # the CAC loss proposes its own method for score calculation;
                # the minimum distance could be used instead
                ood_metrics.update(CACLoss.score(dists), targets)

                known_mask = is_known(targets)
                if not known_mask.any():
                    continue

                # predicted class = index of the closest center
                predictions = dists[known_mask].min(dim=1).indices
                accuracy.update(predictions, targets[known_mask])

    print(ood_metrics.compute())
    print(f"Accuracy: {accuracy.compute().item():.2%}")
    model.train()
Start training
# Train for n_epochs. The learning-rate scheduler is stepped once per batch.
for epoch in range(n_epochs):
    # Exponential moving average of the batch loss, restarted every epoch.
    # None means "no batch seen yet", so a genuine value of 0.0 is not
    # mistaken for the uninitialised state.
    loss_ema = None

    with tqdm(train_loader, desc=f"Epoch {epoch}") as bar:
        for x, y in bar:
            # calculate embeddings
            z = model(x.to(device))
            # calculate the distance of each embedding to each center
            distances = criterion.distance(z)
            # calculate CAC loss, based on distances to centers; the labels
            # go to the same configured device as the inputs
            loss = criterion(distances, y.to(device))
            opti.zero_grad()
            loss.backward()
            opti.step()
            scheduler.step()

            # read the scalar once instead of calling .item() twice
            loss_value = loss.item()
            loss_ema = (
                loss_value if loss_ema is None else 0.99 * loss_ema + 0.01 * loss_value
            )
            bar.set_postfix_str(f"loss: {loss_ema:.3f} lr: {scheduler.get_last_lr()[0]:.6f}")

# evaluate the trained model
test()
{'AUROC': 0.8958120346069336, 'AUPR-IN': 0.8456918001174927, 'AUPR-OUT': 0.8196930885314941, 'FPR95TPR': 0.5187000036239624}
Accuracy: 93.80%