Class Anchor Clustering

Class Anchor Clustering (CAC) can be seen as a multi-class generalization of Deep One-Class Learning, where there are several centers \(\{\mu_1, \mu_2, \dots, \mu_C\}\) in the output space of the model, one for each of the \(C\) classes. During training, the representation \(f_{\theta}(x)\) of a sample \(x\) from class \(y\) is drawn towards the corresponding center \(\mu_y\).

Here, we train the model for 10 epochs on the CIFAR10 dataset, using a backbone pre-trained on a \(32 \times 32\) resized version of ImageNet as a foundation.

14 import torch
15 from torch.optim import Adam
16 from torch.optim.lr_scheduler import CosineAnnealingLR
17 from torch.utils.data import DataLoader
18 from torchmetrics import Accuracy
19 from torchvision.datasets import CIFAR10
20 from tqdm import tqdm
21
22 from pytorch_ood.dataset.img import Textures
23 from pytorch_ood.loss import CACLoss
24 from pytorch_ood.model import WideResNet
25 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
26
# fix the random seed before any model or data loader is created
fix_random_seed(123)

n_epochs = 10
# Prefer the first GPU, but fall back to the CPU so that the example does not
# crash on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# input transform registered under the same key as the pre-trained weights
trans = WideResNet.transform_for("imagenet32-nocifar")

# setup IN training data
dataset_in_train = CIFAR10(root="data", train=True, download=True, transform=trans)

# setup IN test data
dataset_in_test = CIFAR10(root="data", train=False, transform=trans)

# setup OOD test data, use ToUnknown() to mark labels as OOD
dataset_out_test = Textures(
    root="data", download=True, transform=trans, target_transform=ToUnknown()
)

# create data loaders; the test loader iterates over the concatenation of the
# IN test data and the OOD test data
train_loader = DataLoader(dataset_in_train, batch_size=64, shuffle=True, num_workers=16)
test_loader = DataLoader(
    dataset_in_test + dataset_out_test, batch_size=64, num_workers=16
)

Create a DNN that was pre-trained on ImageNet with the CIFAR10 classes excluded. We have to replace the final layer to match the number of classes.

# backbone with the 1000-way head of the pre-trained checkpoint
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
# swap the final layer for one with 10 outputs, one per CIFAR10 class
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=10)
model.to(device)

# CAC loss for 10 classes; it also provides the distance and score functions
# used during evaluation
criterion = CACLoss(n_classes=10, magnitude=5, alpha=2).to(device)

# Adam with default settings; the cosine schedule spans every training batch
# of every epoch, so the scheduler is meant to be stepped once per batch
opti = Adam(model.parameters())
scheduler = CosineAnnealingLR(opti, T_max=n_epochs * len(train_loader))

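Define a function that evaluates the model: it reports the OOD detection metrics and the classification accuracy on the known (in-distribution) samples.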

def test():
    """Evaluate the global ``model`` on ``test_loader`` and print the results.

    Prints the OOD metrics computed from the CAC scores of all test samples,
    followed by the classification accuracy over the samples that
    ``is_known`` reports as known. The model is put into eval mode for the
    duration of the evaluation and switched back to train mode afterwards.
    """
    ood_metrics = OODMetrics()
    # NOTE(review): newer torchmetrics releases may require an explicit
    # ``task`` argument here -- confirm against the installed version.
    accuracy = Accuracy(num_classes=10)

    model.eval()

    with torch.no_grad(), tqdm(test_loader, desc="Testing") as progress:
        for inputs, labels in progress:
            # embed the batch, then measure the distance of every embedding
            # to every class center; metrics are accumulated on the CPU
            embeddings = model(inputs.to(device))
            dists = criterion.distance(embeddings).cpu()

            # CAC defines its own outlier score; the minimum distance would
            # be a possible alternative
            ood_metrics.update(CACLoss.score(dists), labels)

            known_mask = is_known(labels)
            if not known_mask.any():
                # nothing in this batch contributes to the accuracy
                continue

            # the predicted class is the one with the closest center
            predictions = dists[known_mask].min(dim=1).indices
            accuracy.update(predictions, labels[known_mask])

    print(ood_metrics.compute())
    print(f"Accuracy: {accuracy.compute().item():.2%}")
    model.train()

Start training

 93 for epoch in range(n_epochs):
 94     print(f"Epoch {epoch}")
 95     loss_ema = 0
 96
 97     with tqdm(train_loader, desc=f"Epoch {epoch}") as bar:
 98         for x, y in bar:
 99             # calculate embeddings
100             z = model(x.to(device))
101             # calculate the distance of each embedding to each center
102             distances = criterion.distance(z)
103             # calculate CAC loss, based on distances to centers
104             loss = criterion(distances, y.cuda())
105             opti.zero_grad()
106             loss.backward()
107             opti.step()
108             scheduler.step()
109
110             loss_ema = (
111                 loss.item() if not loss_ema else 0.99 * loss_ema + 0.01 * loss.item()
112             )
113             bar.set_postfix_str(
114                 f"loss: {loss_ema:.3f} lr: {scheduler.get_last_lr()[0]:.6f}"
115             )
116
117     test()

Gallery generated by Sphinx-Gallery