Newsgroups 20

Benchmark code for Newsgroups 20. We test the models against three different text datasets and calculate the mean performance.

Uses the GRU model from the OOD detection baseline paper.

The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.

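| Detector    | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|-------------|-------|---------|----------|----------|
| KLMatching  | 80.17 | 84.72   | 70.21    | 54.91    |
| MaxSoftmax  | 80.26 | 84.67   | 72.13    | 53.39    |
| Entropy     | 84.14 | 87.46   | 75.37    | 51.67    |
| ViM         | 87.00 | 88.94   | 81.77    | 39.80    |
| MaxLogit    | 88.32 | 89.59   | 81.95    | 39.24    |
| EnergyBased | 88.82 | 89.85   | 82.61    | 37.96    |
| Mahalanobis | 89.06 | 89.40   | 85.56    | 35.24    |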

31 import pandas as pd
32 import torch
33 import torch.nn.functional as F
34 from torch.utils.data import DataLoader
35 from torchtext.data.utils import get_tokenizer
36 from torchtext.vocab import build_vocab_from_iterator
37 from tqdm import tqdm
38
39 from pytorch_ood.dataset.txt import Multi30k, NewsGroup20, Reuters52, WMT16Sentences
40 from pytorch_ood.detector import (
41     ODIN,
42     EnergyBased,
43     Entropy,
44     KLMatching,
45     Mahalanobis,
46     MaxLogit,
47     MaxSoftmax,
48     ViM,
49 )
50 from pytorch_ood.model import GRUClassifier
51 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
52
# Fix the random seed (123) for the whole run.
fix_random_seed(123)

# Training hyperparameters.
n_epochs = 10
lr = 0.001
# Prefer the first GPU, but fall back to the CPU so the script still runs
# on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Directory passed to the dataset classes as their root folder.
root = "data"

# download datasets
train_dataset = NewsGroup20(root, train=True, download=True)

# Tokenizer shared by vocabulary construction and text preprocessing.
tokenizer = get_tokenizer("basic_english")
67
68
def yield_tokens(data_iter):
    """Lazily yield the tokenized text of every sample in ``data_iter``.

    :param data_iter: iterable of two-element samples; the first element is the
        text handed to the module-level ``tokenizer``, the second is ignored
    :return: generator producing one ``tokenizer(text)`` result per sample
    """
    yield from (tokenizer(text) for text, _unused in data_iter)
72
73
# Build the vocabulary from the tokenized texts of the training split.
vocab = build_vocab_from_iterator(yield_tokens(train_dataset))
# Look-ups of tokens that are not in the vocabulary resolve to index 0.
# NOTE(review): ``collate_batch`` also pads with 0 and no special tokens are
# registered here, so index 0 presumably coincides with a real token - confirm
# whether a dedicated padding/unknown entry is wanted.
vocab.set_default_index(0)
76
77
def prep(x):
    """Turn raw text into an ``int64`` tensor of vocabulary indices.

    :param x: text to tokenize with the module-level ``tokenizer``
    :return: tensor holding ``vocab[token]`` for every token, in order
    """
    token_ids = [vocab[token] for token in tokenizer(x)]
    return torch.tensor(token_ids, dtype=torch.int64)
# Instantiate the train and test splits with ``prep`` as the text transform.
train_dataset = NewsGroup20(root, train=True, transform=prep)
dataset_in_test = NewsGroup20(root, train=False, transform=prep)

Add padding, etc.

 90 def collate_batch(batch):
 91     texts = [i[0] for i in batch]
 92     labels = torch.tensor([i[1] for i in batch], dtype=torch.int64)
 93     t_lengths = torch.tensor([len(t) for t in texts])
 94     max_t_length = torch.max(t_lengths)
 95
 96     padded = []
 97     for text in texts:
 98         t = torch.cat([torch.zeros(max_t_length - len(text), dtype=torch.long), text])
 99         padded.append(t)
100     return torch.stack(padded, dim=0), labels
101
102
# Batch the in-distribution splits; ``collate_batch`` pads the variable-length
# sequences of each batch to a common length.
loader_in_train = DataLoader(train_dataset, batch_size=20, shuffle=True, collate_fn=collate_batch)
loader_in_test = DataLoader(dataset_in_test, batch_size=16, shuffle=True, collate_fn=collate_batch)
print("STAGE 1: Train Model")
# GRU classifier with 20 output classes; ``n_vocab`` is the vocabulary size.
model = GRUClassifier(num_classes=20, n_vocab=len(vocab))
model.to(device)

# Adam over all model parameters, plus a cosine-annealing learning-rate
# schedule whose period (``T_max``) equals the number of epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
# Train for ``n_epochs`` epochs and report the test accuracy after each one.
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    # ---- training pass -------------------------------------------------
    model.train()
    loss_ema = None
    correct = 0
    total = 0

    bar = tqdm(loader_in_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Exponential moving average of the batch loss. Compare against None
        # explicitly so that a genuine value of 0.0 does not reset the average.
        batch_loss = loss.item()
        loss_ema = batch_loss if loss_ema is None else loss_ema * 0.99 + batch_loss * 0.01

        pred = logits.max(dim=1).indices
        correct += pred.eq(labels).sum().item()
        total += pred.shape[0]

        # Show the smoothed loss and the running training accuracy.
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {correct / total:.2%}")

    # Advance the cosine schedule once per epoch (T_max == n_epochs); without
    # this call the learning rate never changes.
    # NOTE(review): earlier versions of this script never stepped the
    # scheduler, so numbers obtained with them may differ slightly.
    scheduler.step()

    # ---- evaluation pass -----------------------------------------------
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader_in_test:
            # Use the configured device rather than the default CUDA device.
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

    print(f"Test Accuracy: {correct / total:.2%}")
# OOD datasets that are each combined with the in-distribution test split.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

# Maps the OOD dataset class name to a loader over (in-distribution test + OOD) samples.
datasets = {}

for ood_dataset in ood_datasets:
    # Use the shared ``root`` setting instead of a second hard-coded path, and
    # apply ``ToUnknown()`` as the target transform of the OOD samples.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader

Fit detectors to training data (some require this, some do not)

print("STAGE 2: Creating OOD Detectors")

# Detector name -> detector instance. The insertion order determines the order
# in which detectors are fitted and later evaluated.
detectors = {
    "Entropy": Entropy(model),
    # ViM and Mahalanobis work on the feature extractor instead of the full model.
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis": Mahalanobis(model.features, eps=0.0),
    "KLMatching": KLMatching(model),
    "MaxSoftmax": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
}


print(f"> Fitting {len(detectors)} detectors")

# Every detector is fitted on the in-distribution training loader.
for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.fit(loader_in_train, device=device)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# One entry per (detector, dataset) pair: the two names plus the computed metrics.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            # Fresh metric accumulator for every detector/dataset combination.
            metrics = OODMetrics()
            for inputs, targets in loader:
                scores = detector(inputs.to(device))
                metrics.update(scores, targets.to(device))

            row = {"Detector": detector_name, "Dataset": dataset_name}
            row.update(metrics.compute())
            results.append(row)
216
217
# calculate mean scores over all datasets, use percent

df = pd.DataFrame(results)
# Average only the numeric metric columns: the string-valued "Dataset" column
# would otherwise make ``mean()`` raise a TypeError on pandas >= 2.0, where
# ``numeric_only`` defaults to False.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
# Print as CSV, sorted by ascending AUROC, with two decimal places.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery