Note
Go to the end to download the full example code.
Newsgroups 20
Benchmark code for Newsgroups 20. We evaluate the OOD detectors against three different out-of-distribution text datasets and report their mean performance.
Uses the GRU model from the OOD detection baseline paper.
The original results cannot be reproduced, as the original dictionaries (word-to-token mappings) are not available.
| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| KLMatching | 80.24 | 41.52 | 72.06 | 83.54 | 55.63 |
| MaxSoftmax | 80.47 | 35.98 | 74.39 | 83.86 | 53.85 |
| Mahalanobis | 83.53 | 40.63 | 79.72 | 84.03 | 46.28 |
| Entropy | 83.61 | 36.23 | 77.28 | 85.89 | 51.74 |
| ViM | 84.19 | 41.52 | 81.51 | 84.42 | 42.44 |
| MaxLogit | 87.86 | 38.30 | 85.11 | 87.65 | 36.81 |
| EnergyBased | 88.35 | 38.17 | 85.84 | 87.83 | 35.52 |
33 import pandas as pd
34 import torch
35 import torch.nn.functional as F
36 from torch.utils.data import DataLoader
37 from torchtext.data.utils import get_tokenizer
38 from torchtext.vocab import build_vocab_from_iterator
39 from tqdm import tqdm
40
41 from pytorch_ood.dataset.txt import Multi30k, NewsGroup20, Reuters52, WMT16Sentences
42 from pytorch_ood.detector import (
43 EnergyBased,
44 Entropy,
45 KLMatching,
46 Mahalanobis,
47 MaxLogit,
48 MaxSoftmax,
49 ViM,
50 )
51 from pytorch_ood.model import GRUClassifier
52 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
53
fix_random_seed(123)

# Training hyper-parameters.
n_epochs = 10
lr = 0.001
# Use the first GPU when present, otherwise fall back to the CPU instead of
# failing on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Directory where the datasets are stored / downloaded to.
root = "data"

# download datasets
# This raw (untransformed) copy of the training set is only used to build the
# vocabulary; it is re-created with a token transform afterwards.
train_dataset = NewsGroup20(root, train=True, download=True)
66
tokenizer = get_tokenizer("basic_english")


def yield_tokens(data_iter):
    """Yield the token list of every ``(text, label)`` sample in *data_iter*.

    :param data_iter: iterable of ``(text, label)`` pairs; the label is ignored
    """
    for text, _ in data_iter:
        yield tokenizer(text)


# Reserve dedicated special tokens: "<pad>" gets index 0 (collate_batch pads
# with zeros) and "<unk>" gets index 1 for out-of-vocabulary words. Without
# specials, index 0 would be an ordinary (the most frequent) training token.
# Padding and every unknown word would then collide with that real word.
vocab = build_vocab_from_iterator(yield_tokens(train_dataset), specials=["<pad>", "<unk>"])
vocab.set_default_index(vocab["<unk>"])
77
78
def prep(x):
    """Encode the string *x* as a tensor of vocabulary indices.

    The text is split with ``tokenizer`` and each token is looked up in
    ``vocab``; the result is a 1-D ``int64`` tensor (empty for empty text).
    """
    token_ids = [vocab[token] for token in tokenizer(x)]
    return torch.tensor(token_ids, dtype=torch.int64)


# Re-create the in-distribution splits, this time encoding every text with ``prep``.
train_dataset = NewsGroup20(root, train=True, transform=prep)
dataset_in_test = NewsGroup20(root, train=False, transform=prep)
Collate function: left-pad every token sequence in a batch with zeros to the length of the longest one.
def collate_batch(batch):
    """Collate ``(token_ids, label)`` samples into two batched tensors.

    Every token sequence is left-padded with zeros up to the length of the
    longest sequence in the batch, so the real tokens sit at the end of each row.

    :param batch: sequence of samples whose item ``0`` is a 1-D token tensor
        and item ``1`` an integer label
    :return: ``(inputs, labels)`` where ``inputs`` has shape
        ``(len(batch), longest)`` and ``labels`` is an ``int64`` tensor of
        shape ``(len(batch),)``
    """
    sequences = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)

    lengths = torch.tensor([len(seq) for seq in sequences])
    longest = int(lengths.max())

    def _left_pad(seq):
        # prepend zeros so that all rows end up with ``longest`` entries
        filler = torch.zeros(longest - len(seq), dtype=torch.long)
        return torch.cat([filler, seq])

    return torch.stack([_left_pad(seq) for seq in sequences], dim=0), labels
102
103
# Both in-distribution loaders pad their batches with ``collate_batch``.
loader_in_train = DataLoader(
    train_dataset,
    batch_size=20,
    shuffle=True,
    collate_fn=collate_batch,
)
loader_in_test = DataLoader(
    dataset_in_test,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_batch,
)
print("STAGE 1: Train Model")
# 20-way classifier over the Newsgroups vocabulary, moved to the target device.
model = GRUClassifier(num_classes=20, n_vocab=len(vocab)).to(device)

# Adam optimizer with a cosine-annealing schedule; T_max=n_epochs assumes the
# scheduler is stepped once per epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    # ---- train for one epoch ----
    model.train()
    loss_ema = None
    correct = 0
    total = 0

    bar = tqdm(loader_in_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Exponential moving average of the loss for a smoother progress display.
        # Test against None (not truthiness) so a loss of exactly 0.0 does not
        # reset the average.
        if loss_ema is None:
            loss_ema = loss.item()
        else:
            loss_ema = loss_ema * 0.99 + loss.item() * 0.01

        pred = logits.max(dim=1).indices
        correct += pred.eq(labels).sum().item()
        total += pred.shape[0]

        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {correct / total:.2%}")

    # Advance the cosine learning-rate schedule once per epoch (T_max=n_epochs).
    scheduler.step()

    # ---- evaluate on the in-distribution test split ----
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for inputs, labels in loader_in_test:
            # honour the configured device instead of hard-coding CUDA
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

        print(f"Test Accuracy: {correct / total:.2%}")
# Each OOD benchmark combines the full in-distribution test split with one
# out-of-distribution text dataset.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

datasets = {}

for ood_dataset in ood_datasets:
    # Reuse the shared ``root`` setting rather than repeating the "data" literal.
    # ToUnknown() presumably relabels these samples as OOD for OODMetrics — confirm.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader
Create the detectors and fit them to the training data (some detectors require fitting, others do not).
print("STAGE 2: Creating OOD Detectors")

# ViM and Mahalanobis are built on ``model.features`` (presumably the feature
# extractor in front of ``model.fc``); the other detectors wrap the full model.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis": Mahalanobis(model.features, eps=0.0),
    "KLMatching": KLMatching(model),
    "MaxSoftmax": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
}


print(f"> Fitting {len(detectors)} detectors")

for detector_name, det in detectors.items():
    print(f"--> Fitting {detector_name}")
    det.to(device)
    det.fit(loader_in_train)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for x, y in loader:
                scores = detector(x.to(device))
                metrics.update(scores, y.to(device))

            # one result row per (detector, dataset) pair
            results.append(
                {"Detector": detector_name, "Dataset": dataset_name, **metrics.compute()}
            )
218
219
# calculate mean scores over all datasets, use percent
# The "Dataset" column holds strings, so drop it before averaging; since
# pandas 2.0, GroupBy.mean() raises a TypeError on non-numeric columns.
df = pd.DataFrame(results)
mean_scores = df.drop(columns=["Dataset"]).groupby("Detector").mean() * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))