MCHAD

Multi-Class Hypersphere Anomaly Detection (MCHAD) can be seen as a multi-class generalization of Deep One-Class Learning. Instead of a single center, there are several centers \(\{\mu_1, \mu_2, \dots, \mu_C\}\) in the output space of the model, one for each of the \(C\) classes. During training, the representation \(f_{\theta}(x)\) of a sample \(x\) from class \(y\) is drawn towards the corresponding center \(\mu_y\).

In contrast to Class Anchor Clustering, the positions of the class centers can be learned, and the dimensionality of the output space can be chosen freely. The method can also incorporate auxiliary outliers into training. On the downside, MCHAD has more hyperparameters.

Here, we train the model for 10 epochs on the CIFAR10 dataset, starting from a backbone pre-trained on the \(32 \times 32\) downscaled version of ImageNet. We use TinyImages300k as auxiliary outlier data during training, and the Textures dataset as OOD data for testing.

You can run this example with:

python examples/loss/supervised/mchad.py
27 import torch
28 from torch.optim.lr_scheduler import CosineAnnealingLR
29 from torch.optim import Adam
30 from torch.utils.data import DataLoader, random_split
31 from torchmetrics import Accuracy
32 from torchvision.datasets import CIFAR10
33 from tqdm import tqdm
34 import math
35
36 from pytorch_ood.dataset.img import Textures, TinyImages300k
37 from pytorch_ood.loss import MCHADLoss
38 from pytorch_ood.model import WideResNet
39 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
40
fix_random_seed(123)

# ---- experiment configuration ----
data_root = "data"
n_epochs = 10
device = "cuda:0"
embedding_dim = 7  # size of the output (embedding) space in which the class centers live
margin = math.sqrt(embedding_dim)  # margin for MCHADLoss, grows with the embedding size
batch_size = 256

# preprocessing that matches the pre-trained WideResNet backbone
trans = WideResNet.transform_for("imagenet32-nocifar")

# ---- in-distribution training data ----
data_in_train = CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=trans,
)

# ---- auxiliary outliers for training; ToUnknown() marks their labels as OOD ----
tiny300k = TinyImages300k(
    root=data_root,
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# ---- in-distribution test data ----
data_in_test = CIFAR10(root=data_root, train=False, transform=trans)

# ---- OOD test data; ToUnknown() marks their labels as OOD ----
data_out_test = Textures(
    root=data_root,
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# ---- data loaders ----
test_loader = DataLoader(
    data_in_test + data_out_test,
    batch_size=batch_size,
    num_workers=16,
)

# use a random subset of the outliers that is exactly as large as the IN training set
n_in_train = len(data_in_train)
data_out_train, _ = random_split(tiny300k, [n_in_train, len(tiny300k) - n_in_train])
train_loader = DataLoader(
    data_in_train + data_out_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)

Create a DNN pre-trained on ImageNet, excluding the CIFAR10 classes. We have to replace the final layer so that the network outputs an embedding of dimension `embedding_dim` instead of 1000 class logits.

# Backbone pre-trained on 32x32 ImageNet without the CIFAR classes. The final
# layer is replaced so the network emits an `embedding_dim`-dimensional embedding
# instead of 1000 class logits.
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
model.fc = torch.nn.Linear(model.fc.in_features, embedding_dim)
model.to(device)

# MCHAD loss with one center per class (10 classes) in the embedding space.
criterion = MCHADLoss(
    n_classes=10, n_dim=embedding_dim, weight_oe=0.01, weight_center=2, margin=margin
).to(device)

# Optimize the network AND the loss' own parameters. MCHAD's class centers are
# meant to be learned, which only happens if the optimizer is given them.
# NOTE(review): assumes MCHADLoss registers its centers as nn.Parameters; if it
# has no parameters, the extra list is empty and this is equivalent to before.
opti = Adam(list(model.parameters()) + list(criterion.parameters()))
# The scheduler is stepped once per batch, so T_max counts batches over all epochs.
scheduler = CosineAnnealingLR(opti, T_max=n_epochs * len(train_loader))

Define a function that evaluates the model's OOD detection performance and its classification accuracy on the test data.

 98 def test():
 99     metrics = OODMetrics()
100     acc = Accuracy(num_classes=10)
101
102     model.eval()
103
104     with torch.no_grad(), tqdm(test_loader, desc="Testing") as bar:
105         for x, y in bar:
106             # calculate embeddings
107             z = model(x.to(device))
108             # calculate the distance of each embedding to each center
109             distances = criterion.distance(z).cpu()
110             metrics.update(distances.min(dim=1).values, y)
111             known = is_known(y)
112             if known.any():
113                 acc.update(distances[known].min(dim=1).indices, y[known])
114
115     print(metrics.compute())
116     print(f"Accuracy: {acc.compute().item():.2%}")
117     model.train()

Start training

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    loss_ema = None  # exponential moving average of the loss, for display only

    with tqdm(train_loader, desc=f"Epoch {epoch}") as bar:
        for x, y in bar:
            # calculate embeddings
            z = model(x.to(device))
            # calculate the distance of each embedding to each center
            distances = criterion.distance(z)
            # calculate MCHAD loss, based on distances to centers; labels are moved
            # to the configured device rather than a hard-coded `.cuda()`
            loss = criterion(distances, y.to(device))
            opti.zero_grad()
            loss.backward()
            opti.step()
            # the cosine schedule is stepped per batch (its T_max counts batches)
            scheduler.step()

            # compare against None explicitly so a loss of exactly 0.0 does not
            # reset the running average
            if loss_ema is None:
                loss_ema = loss.item()
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss.item()
            bar.set_postfix_str(
                f"loss: {loss_ema:.3f} lr: {scheduler.get_last_lr()[0]:.6f}"
            )

    test()

    # draw a fresh random subset of outliers for the next epoch; it has the same
    # size as before, so the number of batches per epoch (and thus the
    # scheduler's T_max) stays valid
    data_out_train, _ = random_split(
        tiny300k, [len(data_in_train), len(tiny300k) - len(data_in_train)]
    )
    train_loader = DataLoader(
        data_in_train + data_out_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
    )

Gallery generated by Sphinx-Gallery