Note
Go to the end to download the full example code.
Class Anchor Clustering
Class Anchor Clustering (CAC) can be seen as a multi-class generalization of Deep
One-Class Learning, where there are
several centers \(\{\mu_1, \mu_2, \dots, \mu_K\}\) in the output space of the model, one for each of the \(K\) classes.
During training, the representation \(f_{\theta}(x)\) from class \(y\) is drawn
towards the corresponding center \(\mu_y\).
Here, we train the model for 10 epochs on the CIFAR10 dataset, using a backbone pre-trained on a \(32 \times 32\) resized version of ImageNet as a foundation.
14 import torch
15 from torch.optim import Adam
16 from torch.optim.lr_scheduler import CosineAnnealingLR
17 from torch.utils.data import DataLoader
18 from torchmetrics import Accuracy
19 from torchvision.datasets import CIFAR10
20 from tqdm import tqdm
21
22 from pytorch_ood.dataset.img import Textures
23 from pytorch_ood.loss import CACLoss
24 from pytorch_ood.model import WideResNet
25 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
26
# run configuration
n_epochs = 10
device = "cuda:0"

# fix the random seed
fix_random_seed(123)

# input transform for the "imagenet32-nocifar" configuration
trans = WideResNet.transform_for("imagenet32-nocifar")

# in-distribution (IN) data: CIFAR10 training and test splits
dataset_in_train = CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=trans,
)
dataset_in_test = CIFAR10(
    root="data",
    train=False,
    transform=trans,
)

# out-of-distribution (OOD) test data; ToUnknown() marks the labels as OOD
dataset_out_test = Textures(
    root="data",
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# data loaders: shuffled IN data for training, IN and OOD data combined
# for testing
train_loader = DataLoader(
    dataset_in_train,
    batch_size=64,
    shuffle=True,
    num_workers=16,
)
test_loader = DataLoader(
    dataset_in_test + dataset_out_test,
    batch_size=64,
    num_workers=16,
)
Create the DNN, pre-trained on ImageNet excluding the CIFAR10 classes. We have to replace the final layer to match the number of classes.
# The pretrained network is created with a 1000-class output layer, which is
# then replaced by a fresh 10-class linear layer of the same input width.
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,
    out_features=10,
)
model.to(device)

opti = Adam(params=model.parameters())
criterion = CACLoss(n_classes=10, magnitude=5, alpha=2).to(device)
# scheduler.step() is called once per batch in the training loop, so the
# annealing period is expressed as the total number of batches
scheduler = CosineAnnealingLR(
    optimizer=opti,
    T_max=n_epochs * len(train_loader),
)
Define a function that evaluates the model
def test():
    """
    Evaluate the global ``model`` on ``test_loader``.

    Prints the computed OOD metrics and the classification accuracy measured
    on the samples that ``is_known`` reports as known, then switches the
    model back to training mode.
    """
    ood_metrics = OODMetrics()
    # NOTE(review): newer torchmetrics releases appear to require an explicit
    # ``task=`` argument here -- confirm against the installed version.
    accuracy = Accuracy(num_classes=10)

    model.eval()

    with torch.no_grad():
        with tqdm(test_loader, desc="Testing") as progress:
            for inputs, labels in progress:
                # embed the batch, then measure the distance of every
                # embedding to every class center
                embeddings = model(inputs.to(device))
                dists = criterion.distance(embeddings).cpu()

                # CACLoss ships its own scoring rule; the minimum distance
                # would be an alternative outlier score
                ood_metrics.update(CACLoss.score(dists), labels)

                known_mask = is_known(labels)
                if not known_mask.any():
                    continue

                # the predicted class is the index of the closest center
                predictions = dists[known_mask].min(dim=1).indices
                accuracy.update(predictions, labels[known_mask])

    print(ood_metrics.compute())
    print(f"Accuracy: {accuracy.compute().item():.2%}")
    model.train()
Start training
# Train for ``n_epochs`` epochs and evaluate after each one.
# NOTE(review): the indentation of this listing was lost; ``test()`` is
# assumed to run once per epoch (it switches the model back to training mode
# when it finishes) -- confirm against the original script.
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    # exponential moving average of the loss, shown in the progress bar;
    # ``None`` (rather than 0) marks "no batch seen yet" so that a genuine
    # zero average is never mistaken for an uninitialised value
    loss_ema = None

    with tqdm(train_loader, desc=f"Epoch {epoch}") as bar:
        for x, y in bar:
            # calculate embeddings
            z = model(x.to(device))
            # calculate the distance of each embedding to each center
            distances = criterion.distance(z)
            # calculate CAC loss, based on distances to centers; the labels
            # follow the configured ``device`` instead of a hard-coded CUDA
            # transfer, so changing ``device`` does not break this line
            loss = criterion(distances, y.to(device))
            opti.zero_grad()
            loss.backward()
            opti.step()
            # the scheduler is stepped per batch, not per epoch
            scheduler.step()

            # read the scalar loss once per step
            loss_value = loss.item()
            if loss_ema is None:
                loss_ema = loss_value
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss_value

            bar.set_postfix_str(
                f"loss: {loss_ema:.3f} lr: {scheduler.get_last_lr()[0]:.6f}"
            )

    test()