Note
Go to the end to download the full example code.
Newsgroups 20
Benchmark code for Newsgroups 20. We test the models against three different text datasets and calculate the mean performance.
Uses the GRU model from the OOD detection baseline paper.
The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.
| Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|
| KLMatching | 80.24 | 41.52 | 72.06 | 83.54 | 55.63 |
| MaxSoftmax | 80.47 | 35.98 | 74.39 | 83.86 | 53.85 |
| Mahalanobis | 83.53 | 40.63 | 79.72 | 84.03 | 46.28 |
| Entropy | 83.61 | 36.23 | 77.28 | 85.89 | 51.74 |
| ViM | 84.19 | 41.52 | 81.51 | 84.42 | 42.44 |
| MaxLogit | 87.86 | 38.30 | 85.11 | 87.65 | 36.81 |
| EnergyBased | 88.35 | 38.17 | 85.84 | 87.83 | 35.52 |
33 import pandas as pd
34 import torch
35 import torch.nn.functional as F
36 from torch.utils.data import DataLoader
37 from torchtext.data.utils import get_tokenizer
38 from torchtext.vocab import build_vocab_from_iterator
39 from tqdm import tqdm
40
41 from pytorch_ood.dataset.txt import Multi30k, NewsGroup20, Reuters52, WMT16Sentences
42 from pytorch_ood.detector import (
43 EnergyBased,
44 Entropy,
45 KLMatching,
46 Mahalanobis,
47 MaxLogit,
48 MaxSoftmax,
49 ViM,
50 )
51 from pytorch_ood.model import GRUClassifier
52 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
53
# Fix the random seed for this run.
fix_random_seed(123)

# Training hyper-parameters.
n_epochs = 10
lr = 0.001

# Prefer the first GPU, but fall back to the CPU so the script can still run
# on machines without CUDA instead of failing when tensors are moved.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Directory passed to the dataset classes as their root.
root = "data"
# Tokenizer shared by the vocabulary builder and the ``prep`` transform.
tokenizer = get_tokenizer("basic_english")

# Download the training split; no transform is attached here.
train_dataset = NewsGroup20(
    root,
    train=True,
    download=True,
)
68
69
def yield_tokens(data_iter):
    """Lazily tokenize the text of every ``(text, label)`` pair in ``data_iter``.

    Each yielded item is whatever the module-level ``tokenizer`` returns for
    one text; the second element of every pair is ignored.
    """
    for sample_text, _unused_label in data_iter:
        tokens = tokenizer(sample_text)
        yield tokens
73
74
# Build the vocabulary from the tokenized training texts.
vocab = build_vocab_from_iterator(
    yield_tokens(train_dataset),
)
# Lookups of tokens that are not in the vocabulary resolve to index 0.
vocab.set_default_index(0)
77
78
def prep(x):
    """Convert a raw text into a 1-D ``int64`` tensor of vocabulary indices.

    The text is split with the module-level ``tokenizer`` and every token is
    looked up in the module-level ``vocab``.
    """
    tokens = tokenizer(x)
    indices = [vocab[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.int64)
# Re-create the training split with the ``prep`` transform attached, and load
# the matching in-distribution test split the same way.
train_dataset = NewsGroup20(
    root,
    train=True,
    transform=prep,
)
dataset_in_test = NewsGroup20(
    root,
    train=False,
    transform=prep,
)
Add padding, etc.
def collate_batch(batch):
    """Turn a list of ``(token_ids, label)`` samples into padded batch tensors.

    Every sequence is left-padded with zeros up to the length of the longest
    sequence in the batch.

    Returns:
        A tuple ``(padded, labels)``: ``padded`` stacks the padded sequences
        along dimension 0 and ``labels`` is an ``int64`` tensor holding the
        label of each sample.
    """
    sequences = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)

    # Length of the longest sequence in this batch (kept as a 0-dim tensor).
    longest = torch.tensor([len(seq) for seq in sequences]).max()

    # Prepend as many zeros as each sequence is short of the longest one.
    padded = [
        torch.cat([torch.zeros(longest - len(seq), dtype=torch.long), seq])
        for seq in sequences
    ]
    return torch.stack(padded, dim=0), labels
102
103
# Batched loaders over the in-distribution splits; ``collate_batch`` pads the
# variable-length sequences of each batch to a common length.
loader_in_train = DataLoader(
    train_dataset,
    batch_size=20,
    shuffle=True,
    collate_fn=collate_batch,
)
loader_in_test = DataLoader(
    dataset_in_test,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_batch,
)
print("STAGE 1: Train Model")

# GRU classifier with one output per newsgroup and one vocabulary slot per
# known token, placed on the configured device.
model = GRUClassifier(num_classes=20, n_vocab=len(vocab)).to(device)

# Adam optimizer plus a cosine-annealing schedule spanning ``n_epochs`` steps.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=n_epochs,
)
# Train for ``n_epochs`` epochs; after each epoch, report the accuracy on the
# in-distribution test split.
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    # ---- training pass ----
    model.train()
    loss_ema = None  # exponential moving average of the batch loss
    correct = 0
    total = 0

    bar = tqdm(loader_in_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)

        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compare against None explicitly: a loss of exactly 0.0 is falsy and
        # would otherwise restart the running average.
        batch_loss = loss.item()
        if loss_ema is None:
            loss_ema = batch_loss
        else:
            loss_ema = loss_ema * 0.99 + batch_loss * 0.01

        pred = logits.max(dim=1).indices
        correct += pred.eq(labels).sum().item()
        total += pred.shape[0]

        # Show the smoothed loss; the raw per-batch value is too noisy to read.
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {correct / total:.2%}")

    # Advance the learning-rate schedule once per epoch. The scheduler is
    # configured with T_max=n_epochs and has no effect unless it is stepped.
    scheduler.step()

    # ---- evaluation pass ----
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader_in_test:
            # Use the configured device rather than hard-coding CUDA.
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

    print(f"Test Accuracy: {correct / total:.2%}")
# Corpora used as out-of-distribution data.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

# Maps the dataset class name to a loader over the in-distribution test split
# concatenated with that OOD dataset.
datasets = {}

for ood_dataset in ood_datasets:
    # Use the shared ``root`` setting instead of repeating the literal path, so
    # all datasets live in the same directory when the setting is changed.
    # NOTE(review): ToUnknown() presumably relabels targets as "unknown"/OOD -
    # confirm against the pytorch_ood documentation.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader
Fit detectors to training data (some require this, some do not)
print("STAGE 2: Creating OOD Detectors")

# One detector instance per method, keyed by the name used in the result table.
# ViM and Mahalanobis receive the model's feature extractor, the others the
# full model.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis": Mahalanobis(model.features, eps=0.0),
    "KLMatching": KLMatching(model),
    "MaxSoftmax": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
}

print(f"> Fitting {len(detectors)} detectors")

# ``fit`` is called on every detector with the training loader.
for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.fit(loader_in_train, device=device)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")

# One row per (detector, dataset) pair, holding the computed metrics.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for inputs, targets in loader:
                scores = detector(inputs.to(device))
                metrics.update(scores, targets.to(device))

            row = {"Detector": detector_name, "Dataset": dataset_name}
            row.update(metrics.compute())
            results.append(row)
217
218
# Average every metric over all OOD datasets per detector, in percent.
# ``numeric_only=True`` keeps the string-valued "Dataset" column out of the
# mean; without it pandas >= 2.0 raises a TypeError for that column.
df = pd.DataFrame(results)
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100

# Print the table as CSV, ordered by ascending AUROC.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))