Note
Go to the end to download the full example code.
Newsgroups 20
Benchmark code for Newsgroups 20. We test the models against three different text datasets and calculate the mean performance.
Uses GRU model from the OOD detection baseline paper.
The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.
| Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|
| KLMatching | 80.17 | 84.72 | 70.21 | 54.91 |
| MaxSoftmax | 80.26 | 84.67 | 72.13 | 53.39 |
| Entropy | 84.14 | 87.46 | 75.37 | 51.67 |
| ViM | 87.00 | 88.94 | 81.77 | 39.80 |
| MaxLogit | 88.32 | 89.59 | 81.95 | 39.24 |
| EnergyBased | 88.82 | 89.85 | 82.61 | 37.96 |
| Mahalanobis | 89.06 | 89.40 | 85.56 | 35.24 |
31 import pandas as pd
32 import torch
33 import torch.nn.functional as F
34 from torch.utils.data import DataLoader
35 from torchtext.data.utils import get_tokenizer
36 from torchtext.vocab import build_vocab_from_iterator
37 from tqdm import tqdm
38
39 from pytorch_ood.dataset.txt import Multi30k, NewsGroup20, Reuters52, WMT16Sentences
40 from pytorch_ood.detector import (
41 ODIN,
42 EnergyBased,
43 Entropy,
44 KLMatching,
45 Mahalanobis,
46 MaxLogit,
47 MaxSoftmax,
48 ViM,
49 )
50 from pytorch_ood.model import GRUClassifier
51 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
52
# Fix the random seed (123) before anything else is created.
fix_random_seed(123)

n_epochs = 10  # number of training epochs
lr = 0.001  # learning rate passed to the optimizer
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the script does not crash on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
root = "data"  # root directory passed to the dataset constructors
# Tokenizer shared by the vocabulary construction and the text transform.
tokenizer = get_tokenizer("basic_english")

# Raw (untransformed) training split; download=True requests a download.
train_dataset = NewsGroup20(root, train=True, download=True)
67
68
def yield_tokens(data_iter):
    """Lazily yield the tokenizer output for every text in ``data_iter``.

    ``data_iter`` must yield 2-element items; the first element is passed to
    the module-level ``tokenizer`` and the second one is ignored.
    """
    yield from (tokenizer(document) for document, _label in data_iter)
72
73
# Build the vocabulary from the tokenized training texts.
vocab = build_vocab_from_iterator(
    yield_tokens(train_dataset),
)
# Index returned for tokens that are not part of the vocabulary.
vocab.set_default_index(0)
76
77
def prep(x):
    """Convert the raw text ``x`` into a 1-D ``int64`` tensor of vocabulary indices.

    The text is split with the module-level ``tokenizer`` and every token is
    looked up in the module-level ``vocab``.
    """
    tokens = tokenizer(x)
    indices = [vocab[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.int64)
# Re-create the training split and create the test split, this time with the
# text-to-index transform attached (train split first, then test split).
train_dataset, dataset_in_test = (
    NewsGroup20(root, train=is_train, transform=prep) for is_train in (True, False)
)
Add padding, etc.
def collate_batch(batch, pad_value=0):
    """Collate ``(token_ids, label)`` samples into one left-padded batch.

    Args:
        batch: sequence of samples; ``sample[0]`` is a 1-D integer tensor of
            token indices and ``sample[1]`` is an integer label.
        pad_value: index used to fill the front of sequences that are shorter
            than the longest one in the batch. Defaults to ``0``, which is
            the value that was previously hard-coded.

    Returns:
        A tuple ``(padded, labels)`` where ``padded`` stacks the sequences
        along dim 0, each padded on the left to the length of the longest
        sequence, and ``labels`` is an ``int64`` tensor of the labels.
    """
    texts = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    # A plain Python max avoids building a throwaway length tensor.
    max_t_length = max((len(text) for text in texts), default=0)

    padded = [
        torch.cat(
            [torch.full((max_t_length - len(text),), pad_value, dtype=torch.long), text]
        )
        for text in texts
    ]
    return torch.stack(padded, dim=0), labels
101
102
# Batches of the in-distribution training split (20 samples per batch).
loader_in_train = DataLoader(
    train_dataset,
    batch_size=20,
    shuffle=True,
    collate_fn=collate_batch,
)
# Batches of the in-distribution test split (16 samples per batch).
loader_in_test = DataLoader(
    dataset_in_test,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_batch,
)
print("STAGE 1: Train Model")
model = GRUClassifier(num_classes=20, n_vocab=len(vocab))
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    model.train()
    loss_ema = None  # exponential moving average of the batch loss
    correct = 0
    total = 0

    bar = tqdm(loader_in_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # Compare against None explicitly: a legitimate average of 0.0 must
        # not restart the moving average.
        loss_ema = batch_loss if loss_ema is None else loss_ema * 0.99 + batch_loss * 0.01

        pred = logits.argmax(dim=1)
        correct += pred.eq(labels).sum().item()
        total += pred.shape[0]

        # Show the smoothed loss that is tracked above together with the
        # running training accuracy of this epoch.
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {correct / total:.2%}")

    # Advance the cosine schedule once per epoch; without this call the
    # scheduler has no effect and the learning rate stays constant.
    scheduler.step()

    # Accuracy on the in-distribution test split after this epoch.
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for inputs, labels in loader_in_test:
            # Use the configured device instead of a hard-coded .cuda().
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.argmax(dim=1)
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

        print(f"Test Accuracy: {correct / total:.2%}")
# Dataset classes whose samples serve as out-of-distribution test data.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

# Maps the OOD dataset class name to a loader over the in-distribution test
# split concatenated with that OOD dataset.
datasets = {}

for ood_dataset in ood_datasets:
    # Use the shared ``root`` setting instead of repeating the "data" literal,
    # so all datasets live in the same configurable directory.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader
Fit detectors to training data (some require this, some do not)
print("STAGE 2: Creating OOD Detectors")

# Detector name -> detector instance. ViM and Mahalanobis are built on the
# model's feature extractor, the remaining detectors on the full model.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis": Mahalanobis(model.features, eps=0.0),
    "KLMatching": KLMatching(model),
    "MaxSoftmax": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
}


print(f"> Fitting {len(detectors)} detectors")

# fit() is called on every detector with the in-distribution training loader.
for name in detectors:
    print(f"--> Fitting {name}")
    detectors[name].fit(loader_in_train, device=device)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []  # one row (dict) per detector / dataset combination

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for x, y in loader:
                metrics.update(detector(x.to(device)), y.to(device))

            r = {"Detector": detector_name, "Dataset": dataset_name}

            r.update(metrics.compute())
            results.append(r)


# calculate mean scores over all datasets, use percent

df = pd.DataFrame(results)
# numeric_only=True keeps the string-valued "Dataset" column out of the
# aggregation; without it, pandas >= 2.0 raises a TypeError here.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))
222 print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))