Class Anchor Clustering

Class Anchor Clustering (CAC) can be seen as a multi-class generalization of Deep One-Class Learning: instead of a single center, there are several centers \(\{\mu_1, \mu_2, \dots, \mu_C\}\) in the output space of the model, one for each of the \(C\) classes. During training, the representation \(f_{\theta}(x)\) of a sample \(x\) from class \(y\) is drawn towards the corresponding center \(\mu_y\).

Here, we train the model for 10 epochs on the CIFAR10 dataset, starting from a backbone that was pre-trained on a \(32 \times 32\) downsampled version of ImageNet.

15 import torch
16 from torch.optim import Adam
17 from torch.optim.lr_scheduler import CosineAnnealingLR
18 from torch.utils.data import DataLoader
19 from torchmetrics import Accuracy
20 from torchvision.datasets import CIFAR10
21 from tqdm import tqdm
22
23 from pytorch_ood.dataset.img import Textures
24 from pytorch_ood.loss import CACLoss
25 from pytorch_ood.model import WideResNet
26 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
27
# Fix the random seed (pytorch-ood helper) so that runs are repeatable.
fix_random_seed(123)

n_epochs = 10
# Fall back to the CPU when no CUDA device is available, so the example still
# runs (slowly) on machines without a GPU instead of crashing in model.to(device).
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Input preprocessing for the "imagenet32-nocifar" weights loaded below
# (presumably the resizing/normalization those weights were trained with).
trans = WideResNet.transform_for("imagenet32-nocifar")

# setup IN training data
dataset_in_train = CIFAR10(root="data", train=True, download=True, transform=trans)

# setup IN test data; download=True makes this independent of the train split
# having been fetched first (it is a no-op when the files are already present).
dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)

# setup OOD test data, use ToUnknown() to mark labels as OOD
dataset_out_test = Textures(
    root="data", download=True, transform=trans, target_transform=ToUnknown()
)

# create data loaders
train_loader = DataLoader(dataset_in_train, batch_size=64, shuffle=True, num_workers=16)
# `+` concatenates the IN and OOD test sets into one dataset; the OOD labels were
# marked by ToUnknown(), so evaluation can separate them again with is_known().
test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=64, num_workers=16)

Create a DNN pre-trained on ImageNet with the CIFAR10 classes excluded. We have to replace the final layer so that its output dimension matches the number of classes (10).

# Backbone pre-trained on downsampled ImageNet without the CIFAR classes. Its
# 1000-way head is swapped for a fresh linear layer with one output per CIFAR10
# class; CAC measures distances to the class centers in this 10-dim output space.
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=10)
model.to(device)

# Optimizer over the network weights only; the loss is built separately below.
opti = Adam(params=model.parameters())
# CAC loss over 10 class centers (see the pytorch-ood CACLoss docs for the
# meaning of `magnitude` and `alpha`).
criterion = CACLoss(n_classes=10, magnitude=5, alpha=2).to(device)
# The training loop calls scheduler.step() once per batch, so T_max is the
# total number of optimizer steps, not the number of epochs.
scheduler = CosineAnnealingLR(optimizer=opti, T_max=n_epochs * len(train_loader))

Define a function that evaluates the model on the combined in-distribution and OOD test data.

def test():
    """Evaluate the model on the combined IN + OOD test loader.

    OOD detection metrics are computed from the CAC score of the distances to
    the class centers. Closed-set accuracy is computed on the known
    (in-distribution) samples only, predicting the class whose center is
    nearest. Both results are printed.

    The model's previous train/eval mode is restored afterwards, even if
    evaluation raises.

    Returns:
        tuple: ``(ood_metrics, accuracy)`` -- the dict produced by
        ``OODMetrics.compute()`` and the accuracy as a Python float in [0, 1].
    """
    metrics = OODMetrics()
    acc = Accuracy(num_classes=10, task="multiclass")

    # Remember the caller's mode instead of unconditionally switching to train().
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad(), tqdm(test_loader, desc="Testing") as bar:
            for x, y in bar:
                # calculate embeddings
                z = model(x.to(device))
                # calculate the distance of each embedding to each center
                distances = criterion.distance(z).cpu()
                # the CAC Loss proposes its own method for score calculation.
                # We could, however, also use the minimum distance.
                metrics.update(CACLoss.score(distances), y)
                # OOD samples carry ToUnknown() labels and must not count towards accuracy.
                known = is_known(y)
                if known.any():
                    acc.update(distances[known].min(dim=1).indices, y[known])
    finally:
        model.train(was_training)

    ood_results = metrics.compute()
    accuracy = acc.compute().item()
    print(ood_results)
    print(f"Accuracy: {accuracy:.2%}")
    return ood_results, accuracy

Start training

 92 for epoch in range(n_epochs):
 93     loss_ema = 0
 94
 95     with tqdm(train_loader, desc=f"Epoch {epoch}") as bar:
 96         for x, y in bar:
 97             # calculate embeddings
 98             z = model(x.to(device))
 99             # calculate the distance of each embedding to each center
100             distances = criterion.distance(z)
101             # calculate CAC loss, based on distances to centers
102             loss = criterion(distances, y.cuda())
103             opti.zero_grad()
104             loss.backward()
105             opti.step()
106             scheduler.step()
107
108             loss_ema = loss.item() if not loss_ema else 0.99 * loss_ema + 0.01 * loss.item()
109             bar.set_postfix_str(f"loss: {loss_ema:.3f} lr: {scheduler.get_last_lr()[0]:.6f}")
110
111         test()

{'AUROC': 0.8958120346069336, 'AUPR-IN': 0.8456918001174927, 'AUPR-OUT': 0.8196930885314941, 'FPR95TPR': 0.5187000036239624}

Accuracy: 93.80%

Gallery generated by Sphinx-Gallery