Note
Go to the end to download the full example code.
MCHAD
Multi Class Hypersphere Anomaly Detection (MCHAD) can be seen as a
multi-class generalization of Deep One-Class Learning, where there are
several centers \(\{\mu_1, \mu_2, ..., \mu_y\}\) in the output space of the model, one for each class.
During training, the representation \(f_{\theta}(x)\) from class \(y\) is drawn
towards the corresponding center \(\mu_y\).
In contrast to Class Anchor Clustering, the position of the class-centers can be learned, and the dimensionality of the output space can be chosen freely. Also, the method is able to incorporate outliers into the training. On the downside, MCHAD has more hyperparameters.
Here, we train the model for 10 epochs on the CIFAR10 dataset, using a backbone pre-trained on a \(32 \times 32\) resized version of ImageNet as a foundation. We use TinyImages300k as training outlier data.
You can run this example with:
python examples/loss/supervised/mchad.py
27 import torch
28 from torch.optim.lr_scheduler import CosineAnnealingLR
29 from torch.optim import Adam
30 from torch.utils.data import DataLoader, random_split
31 from torchmetrics import Accuracy
32 from torchvision.datasets import CIFAR10
33 from tqdm import tqdm
34 import math
35
36 from pytorch_ood.dataset.img import Textures, TinyImages300k
37 from pytorch_ood.loss import MCHADLoss
38 from pytorch_ood.model import WideResNet
39 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
40
# fix the random seed before any random operation (e.g. random_split) is run
fix_random_seed(123)

# experiment configuration
data_root = "data"
n_epochs = 10
device = "cuda:0"
embedding_dim = 7  # dimensionality of output space
margin = math.sqrt(embedding_dim)
batch_size = 256

# input transformation that belongs to the "imagenet32-nocifar" weights
trans = WideResNet.transform_for("imagenet32-nocifar")

# IN training data
data_in_train = CIFAR10(root=data_root, train=True, download=True, transform=trans)

# pool of OOD training data; ToUnknown() is applied to its labels.
# A subset with the same size as the IN training data is drawn from it below.
tiny300k = TinyImages300k(
    root=data_root,
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# IN test data
data_in_test = CIFAR10(root=data_root, train=False, transform=trans)

# OOD test data, ToUnknown() marks the labels as OOD
data_out_test = Textures(
    root=data_root,
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# test loader over the concatenation of IN and OOD test data
test_data = data_in_test + data_out_test
test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=16)

# draw as many OOD training samples as there are IN training samples
n_in_train = len(data_in_train)
split_sizes = [n_in_train, len(tiny300k) - n_in_train]
data_out_train, _ = random_split(tiny300k, split_sizes)

# training loader over the concatenation of IN data and the sampled OOD data
train_data = data_in_train + data_out_train
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)
Create the DNN, pre-trained on ImageNet excluding the CIFAR10 classes. We have to replace the final layer so that its output size matches the dimensionality of the embedding space.
# Load the pre-trained backbone and replace its final layer by a linear
# projection into the embedding space (embedding_dim outputs).
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
model.fc = torch.nn.Linear(model.fc.in_features, embedding_dim)
model.to(device)

# The loss is created before the optimizer so that its parameters can be
# handed to the optimizer as well.
criterion = MCHADLoss(
    n_classes=10, n_dim=embedding_dim, weight_oe=0.01, weight_center=2, margin=margin
).to(device)

# Optimize the parameters of the loss module together with the network.
# Passing only model.parameters() would leave any learnable state of the
# criterion untouched during training.
# NOTE(review): assumes MCHADLoss is an nn.Module that registers its class
# centers as parameters - confirm. If it has no parameters, this is a no-op.
opti = Adam(list(model.parameters()) + list(criterion.parameters()))

# cosine schedule over all optimizer steps (scheduler.step() is called once per batch)
scheduler = CosineAnnealingLR(opti, T_max=n_epochs * len(train_loader))
Define a function that evaluates the model
def test():
    """
    Evaluate the model on the test loader and print the results.

    For each sample, the distance to the closest class center is passed to
    ``OODMetrics`` as outlier score, and the index of the closest center is
    used as class prediction for the accuracy, which is computed on the
    samples selected by ``is_known`` only. The model is switched back to
    training mode before returning.
    """
    ood_metrics = OODMetrics()
    accuracy = Accuracy(num_classes=10)

    model.eval()

    with torch.no_grad():
        with tqdm(test_loader, desc="Testing") as progress:
            for inputs, labels in progress:
                # embed the batch and compute the distance of each embedding to each center
                embeddings = model(inputs.to(device))
                dists = criterion.distance(embeddings).cpu()

                # closest center per sample: .values -> outlier score, .indices -> prediction
                nearest = dists.min(dim=1)
                ood_metrics.update(nearest.values, labels)

                # accuracy is only updated with samples that have a known label
                in_mask = is_known(labels)
                if in_mask.any():
                    accuracy.update(nearest.indices[in_mask], labels[in_mask])

    print(ood_metrics.compute())
    print(f"Accuracy: {accuracy.compute().item():.2%}")
    model.train()
Start training
# Train for n_epochs; after each epoch, evaluate the model and draw a fresh
# random subset of the OOD training data.
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    # exponential moving average of the loss, only used for the progress bar
    loss_ema = None

    with tqdm(train_loader, desc=f"Epoch {epoch}") as bar:
        for x, y in bar:
            # calculate embeddings
            z = model(x.to(device))
            # calculate the distance of each embedding to each center
            distances = criterion.distance(z)
            # calculate MCHAD loss, based on distances to centers.
            # The labels are moved to the configured device (not the default
            # CUDA device), so that they live where the distances live.
            loss = criterion(distances, y.to(device))
            opti.zero_grad()
            loss.backward()
            opti.step()
            scheduler.step()

            # compare against None explicitly: a running average of exactly 0.0
            # must not be treated as "not initialized yet"
            loss_value = loss.item()
            if loss_ema is None:
                loss_ema = loss_value
            else:
                loss_ema = 0.99 * loss_ema + 0.01 * loss_value

            bar.set_postfix_str(
                f"loss: {loss_ema:.3f} lr: {scheduler.get_last_lr()[0]:.6f}"
            )

    test()

    # create new random split
    data_out_train, _ = random_split(
        tiny300k, [len(data_in_train), len(tiny300k) - len(data_in_train)]
    )
    train_loader = DataLoader(
        data_in_train + data_out_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
    )