Note
Go to the end to download the full example code.
MCHAD
Multi-Class Hypersphere Anomaly Detection (MCHAD) can be seen as a
multi-class generalization of Deep One-Class Learning, where there are
several centers \(\{\mu_1, \mu_2, \dots, \mu_C\}\) in the output space of the model, one for each of the \(C\) classes.
During training, the representation \(f_{\theta}(x)\) of a sample \(x\) from class \(y\) is pulled
towards the corresponding center \(\mu_y\).
In contrast to Class Anchor Clustering, the positions of the class centers can be learned, and the dimensionality of the output space can be chosen freely. The method can also incorporate outliers into training. On the downside, MCHAD has more hyperparameters.
Here, we train the model for 10 epochs on the CIFAR10 dataset, starting from a backbone pre-trained on a \(32 \times 32\) downscaled version of ImageNet from which the CIFAR10 classes were removed. We use TinyImages300k as outlier data during training and evaluate OOD detection on the Textures dataset.
You can run this example with:
python examples/loss/supervised/mchad.py
28 import torch
29 from torch.optim.lr_scheduler import CosineAnnealingLR
30 from torch.optim import Adam
31 from torch.utils.data import DataLoader, random_split
32 from torchmetrics import Accuracy
33 from torchvision.datasets import CIFAR10
34 from tqdm import tqdm
35 import math
36
37 from pytorch_ood.dataset.img import Textures, TinyImages300k
38 from pytorch_ood.loss import MCHADLoss
39 from pytorch_ood.model import WideResNet
40 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
41
fix_random_seed(123)

data_root = "data"
n_epochs = 10
# use the first GPU when one is available; otherwise fall back to the CPU so the
# example does not crash on machines without CUDA
device = "cuda:0" if torch.cuda.is_available() else "cpu"
embedding_dim = 7  # dimensionality of output space
margin = math.sqrt(embedding_dim)  # passed as `margin` to MCHADLoss below
batch_size = 256

trans = WideResNet.transform_for("imagenet32-nocifar")

# in-distribution (IN) training data
data_in_train = CIFAR10(root=data_root, train=True, download=True, transform=trans)

# pool of training outliers; ToUnknown() marks every label as OOD
tiny300k = TinyImages300k(
    root=data_root,
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# in-distribution (IN) test data
data_in_test = CIFAR10(root=data_root, train=False, transform=trans)

# OOD test data; ToUnknown() marks every label as OOD
data_out_test = Textures(
    root=data_root,
    download=True,
    transform=trans,
    target_transform=ToUnknown(),
)

# the test set is the concatenation of the IN and the OOD test samples
test_data = data_in_test + data_out_test
test_loader = DataLoader(test_data, batch_size=batch_size, num_workers=16)

# use a random subset of the outlier pool that is as large as the IN training set
n_outliers = len(data_in_train)
data_out_train, _ = random_split(tiny300k, [n_outliers, len(tiny300k) - n_outliers])
train_loader = DataLoader(
    data_in_train + data_out_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
)
Create a DNN pre-trained on ImageNet (excluding the CIFAR10 classes). We replace the final layer so that the model outputs embeddings of dimension `embedding_dim` instead of ImageNet class logits.
# Swap the 1000-way ImageNet head for a linear layer that maps the backbone
# features into the embedding_dim-dimensional output space.
model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
model.fc = torch.nn.Linear(model.fc.in_features, embedding_dim)
model.to(device)

# The criterion is created before the optimizer so that its parameters (the
# learnable class centers used by criterion.distance) are optimized jointly with
# the network. Previously only model.parameters() were passed to Adam, which left
# the centers at their initial positions.
# NOTE(review): assumes MCHADLoss registers its centers as nn.Parameters; if it
# has none, the extra list is simply empty and this is a no-op.
criterion = MCHADLoss(
    n_classes=10, n_dim=embedding_dim, weight_oe=0.01, weight_center=2, margin=margin
).to(device)
opti = Adam(list(model.parameters()) + list(criterion.parameters()))
# the scheduler is stepped once per batch, hence T_max counts batches
scheduler = CosineAnnealingLR(opti, T_max=n_epochs * len(train_loader))
Define a function that evaluates the model's OOD detection performance and its classification accuracy on the test set.
def test():
    """Evaluate the model on ``test_loader`` and print the results.

    Every sample is scored by the distance of its embedding to the closest class
    center; these scores and the labels are fed to ``OODMetrics``. For known
    (IN) samples, the index of the closest center is used as the predicted class
    for the accuracy. Afterwards the model is switched back to training mode.
    """
    ood_metrics = OODMetrics()
    accuracy = Accuracy(num_classes=10, task="multiclass")

    model.eval()

    with torch.no_grad(), tqdm(test_loader, desc="Testing") as progress:
        for inputs, labels in progress:
            embeddings = model(inputs.to(device))
            # distance of every embedding to every class center, on the CPU
            dists = criterion.distance(embeddings).cpu()
            ood_metrics.update(dists.min(dim=1).values, labels)

            # classification accuracy is only defined for known samples
            mask = is_known(labels)
            if not mask.any():
                continue
            accuracy.update(dists[mask].min(dim=1).indices, labels[mask])

    print(ood_metrics.compute())
    print(f"Accuracy: {accuracy.compute().item():.2%}")
    model.train()
Start training
for epoch in range(n_epochs):
    # exponential moving average of the loss, only used for the progress bar
    loss_ema = None

    with tqdm(train_loader, desc=f"Epoch {epoch}") as bar:
        for x, y in bar:
            # calculate embeddings
            z = model(x.to(device))
            # calculate the distance of each embedding to each center
            distances = criterion.distance(z)
            # calculate MCHAD loss, based on distances to centers; the labels are
            # moved to `device` (not a hard-coded .cuda()) to match the embeddings
            loss = criterion(distances, y.to(device))
            opti.zero_grad()
            loss.backward()
            opti.step()
            scheduler.step()

            # compare against None: an exact 0.0 average must not reset the EMA
            loss_ema = (
                loss.item() if loss_ema is None else 0.99 * loss_ema + 0.01 * loss.item()
            )
            bar.set_postfix_str(f"loss: {loss_ema:.3f} lr: {scheduler.get_last_lr()[0]:.6f}")

    test()

    # create new random split: draw a fresh outlier subset, as large as the IN
    # training set, for the next epoch
    n_outliers = len(data_in_train)
    data_out_train, _ = random_split(tiny300k, [n_outliers, len(tiny300k) - n_outliers])
    train_loader = DataLoader(
        data_in_train + data_out_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
    )