MCHAD

Multi Class Hypersphere Anomaly Detection (MCHAD) can be seen as a multi-class generalization of Deep One-Class Learning, where there are several centers \(\{\mu_1, \mu_2, \dots, \mu_K\}\) in the output space of the model, one for each of the \(K\) classes. During training, the representation \(f_{\theta}(x)\) of an input \(x\) from class \(y\) is drawn towards the corresponding center \(\mu_y\).

In contrast to Class Anchor Clustering, the position of the class-centers can be learned, and the dimensionality of the output space can be chosen freely. Also, the method is able to incorporate outliers into the training. On the downside, MCHAD has more hyperparameters.

Here, we train the model for 10 epochs on the CIFAR10 dataset, using a backbone pre-trained on a \(32 \times 32\) downscaled version of ImageNet as a foundation. We use TinyImages300k as outlier data during training.

You can run this example with:

python examples/loss/supervised/mchad.py
import math

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision.datasets import CIFAR10
from tqdm import tqdm

from pytorch_ood.dataset.img import Textures, TinyImages300k
from pytorch_ood.loss import MCHADLoss
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known

# seed the random number generators before any data is split or shuffled
fix_random_seed(123)

data_root = "data"  # root directory handed to every dataset
n_epochs = 10
# use the first GPU when one is available, otherwise run on the CPU
device = "cuda:0" if torch.cuda.is_available() else "cpu"
embedding_dim = 7  # dimensionality of output space
margin = math.sqrt(embedding_dim)  # margin handed to MCHADLoss
batch_size = 256

# input transform belonging to the "imagenet32-nocifar" pre-trained weights
trans = WideResNet.transform_for("imagenet32-nocifar")

# setup IN training data
data_in_train = CIFAR10(root=data_root, train=True, download=True, transform=trans)

# setup OOD training data, use ToUnknown() to mark labels as OOD.
# This is the full outlier set; a random subset with as many samples as the
# IN training data is drawn from it when the training loader is built.
tiny300k = TinyImages300k(
    root=data_root, download=True, transform=trans, target_transform=ToUnknown()
)

# setup IN test data
data_in_test = CIFAR10(root=data_root, train=False, transform=trans)

# setup OOD test data, use ToUnknown() to mark labels as OOD
data_out_test = Textures(
    root=data_root, download=True, transform=trans, target_transform=ToUnknown()
)


# create data loaders
# the test loader iterates over the IN and the OOD test data together
test_loader = DataLoader(data_in_test + data_out_test, batch_size=batch_size, num_workers=16)

# draw a random subset of the outlier data that has as many samples as the
# IN training data; the remainder of the split is discarded
data_out_train, _ = random_split(
    tiny300k, [len(data_in_train), len(tiny300k) - len(data_in_train)]
)
train_loader = DataLoader(
    data_in_train + data_out_train, batch_size=batch_size, shuffle=True, num_workers=16
)

Create a DNN that was pre-trained on ImageNet, excluding the CIFAR10 classes. We have to replace the final layer so that its output matches the dimensionality of the embedding space.

model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
# swap the pre-trained output layer for a linear layer that maps into the embedding space
model.fc = torch.nn.Linear(model.fc.in_features, embedding_dim)
model.to(device)

# NOTE(review): only the model's parameters are handed to Adam. If MCHADLoss holds
# learnable class centers, this optimizer does not update them -- confirm intended.
opti = Adam(model.parameters())
criterion = MCHADLoss(
    n_classes=10, n_dim=embedding_dim, weight_oe=0.01, weight_center=2, margin=margin
).to(device)
# the scheduler is stepped once per batch, so the schedule spans all batches of all epochs
scheduler = CosineAnnealingLR(opti, T_max=n_epochs * len(train_loader))

Define a function that evaluates the model

 97 def test():
 98     metrics = OODMetrics()
 99     acc = Accuracy(num_classes=10, task="multiclass")
100
101     model.eval()
102
103     with torch.no_grad(), tqdm(test_loader, desc="Testing") as bar:
104         for x, y in bar:
105             # calculate embeddings
106             z = model(x.to(device))
107             # calculate the distance of each embedding to each center
108             distances = criterion.distance(z).cpu()
109             metrics.update(distances.min(dim=1).values, y)
110             known = is_known(y)
111             if known.any():
112                 acc.update(distances[known].min(dim=1).indices, y[known])
113
114     print(metrics.compute())
115     print(f"Accuracy: {acc.compute().item():.2%}")
116     model.train()

Start training

for epoch in range(n_epochs):
    # exponential moving average of the loss, only used for the progress bar
    loss_ema = None

    with tqdm(train_loader, desc=f"Epoch {epoch}") as bar:
        for x, y in bar:
            # calculate embeddings
            z = model(x.to(device))
            # calculate the distance of each embedding to each center
            distances = criterion.distance(z)
            # calculate MCHAD loss, based on distances to centers;
            # labels go to the same device as the model instead of a hard-coded GPU
            loss = criterion(distances, y.to(device))
            opti.zero_grad()
            loss.backward()
            opti.step()
            scheduler.step()

            loss_value = loss.item()
            # compare against None explicitly so that an average of exactly 0.0 is not reset
            if loss_ema is None:
                loss_ema = loss_value
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss_value
            bar.set_postfix_str(f"loss: {loss_ema:.3f} lr: {scheduler.get_last_lr()[0]:.6f}")

        # evaluate after every epoch
        test()

    # create new random split, so that the next epoch sees a different subset of the outlier data
    data_out_train, _ = random_split(
        tiny300k, [len(data_in_train), len(tiny300k) - len(data_in_train)]
    )
    train_loader = DataLoader(
        data_in_train + data_out_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
    )

Gallery generated by Sphinx-Gallery