Newsgroups 20

Benchmark code for Newsgroups 20. We test the models against three different text datasets and report the mean performance over them.

Uses GRU model from the OOD detection baseline paper.

The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.

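| Detector    | AUROC | AUTC  | AUPR-IN | AUPR-OUT | FPR95TPR |
|-------------|-------|-------|---------|----------|----------|
| KLMatching  | 80.24 | 41.52 | 72.06   | 83.54    | 55.63    |
| MaxSoftmax  | 80.47 | 35.98 | 74.39   | 83.86    | 53.85    |
| Mahalanobis | 83.53 | 40.63 | 79.72   | 84.03    | 46.28    |
| Entropy     | 83.61 | 36.23 | 77.28   | 85.89    | 51.74    |
| ViM         | 84.19 | 41.52 | 81.51   | 84.42    | 42.44    |
| MaxLogit    | 87.86 | 38.30 | 85.11   | 87.65    | 36.81    |
| EnergyBased | 88.35 | 38.17 | 85.84   | 87.83    | 35.52    |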

33 import pandas as pd
34 import torch
35 import torch.nn.functional as F
36 from torch.utils.data import DataLoader
37 from torchtext.data.utils import get_tokenizer
38 from torchtext.vocab import build_vocab_from_iterator
39 from tqdm import tqdm
40
41 from pytorch_ood.dataset.txt import Multi30k, NewsGroup20, Reuters52, WMT16Sentences
42 from pytorch_ood.detector import (
43     EnergyBased,
44     Entropy,
45     KLMatching,
46     Mahalanobis,
47     MaxLogit,
48     MaxSoftmax,
49     ViM,
50 )
51 from pytorch_ood.model import GRUClassifier
52 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed
53
# Fix the random seed before anything stochastic happens (pytorch_ood helper;
# presumably seeds the relevant RNGs for reproducibility -- confirm).
fix_random_seed(123)

# Training hyper-parameters.
n_epochs = 10
lr = 0.001
# Prefer the first GPU, but fall back to the CPU so that the script still
# runs on machines without CUDA instead of failing at the first `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Directory handed to the dataset classes as their root.
root = "data"

# download datasets
# No transform here: the raw text is needed to build the vocabulary below.
train_dataset = NewsGroup20(root, train=True, download=True)

tokenizer = get_tokenizer("basic_english")
68
69
def yield_tokens(data_iter):
    """Lazily yield the token list of every sample in *data_iter*.

    Each item of *data_iter* must unpack into two values; the first one is
    passed to the module-level ``tokenizer``, the second one is ignored.
    """
    for sample in data_iter:
        raw_text, _ = sample
        yield tokenizer(raw_text)
73
74
# Build the vocabulary from the tokens of the (untransformed) training set.
train_tokens = yield_tokens(train_dataset)
vocab = build_vocab_from_iterator(train_tokens)
# Tokens that are not in the vocabulary are mapped to index 0.
# NOTE(review): no special tokens are registered, so index 0 presumably also
# belongs to a real token, and `collate_batch` pads with 0 as well -- confirm
# that sharing this index is intended.
vocab.set_default_index(0)
77
78
def prep(x):
    """Tokenize the string *x* and return its vocabulary indices as an ``int64`` tensor."""
    token_ids = [vocab[token] for token in tokenizer(x)]
    return torch.tensor(token_ids, dtype=torch.int64)
# Re-create the in-distribution splits, this time with `prep` as transform so
# that samples come out as index tensors. The training split is built first,
# the test split second.
train_dataset, dataset_in_test = (
    NewsGroup20(root, train=is_train, transform=prep) for is_train in (True, False)
)

Add padding, etc.

 91 def collate_batch(batch):
 92     texts = [i[0] for i in batch]
 93     labels = torch.tensor([i[1] for i in batch], dtype=torch.int64)
 94     t_lengths = torch.tensor([len(t) for t in texts])
 95     max_t_length = torch.max(t_lengths)
 96
 97     padded = []
 98     for text in texts:
 99         t = torch.cat([torch.zeros(max_t_length - len(text), dtype=torch.long), text])
100         padded.append(t)
101     return torch.stack(padded, dim=0), labels
102
103
# Both loaders shuffle and rely on `collate_batch` for padding.
loader_in_train = DataLoader(
    train_dataset,
    batch_size=20,
    shuffle=True,
    collate_fn=collate_batch,
)
loader_in_test = DataLoader(
    dataset_in_test,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_batch,
)

print("STAGE 1: Train Model")
# One output per newsgroup; the embedding size follows the vocabulary.
model = GRUClassifier(num_classes=20, n_vocab=len(vocab)).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
# Train for `n_epochs` epochs; after each epoch report the accuracy on the
# in-distribution test set.
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    model.train()
    loss_ema = None  # exponential moving average of the batch loss
    correct = 0
    total = 0

    bar = tqdm(loader_in_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = F.cross_entropy(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        # `is None` rather than truthiness, so that an EMA of exactly 0.0 is
        # not mistaken for "not initialised yet".
        loss_ema = batch_loss if loss_ema is None else loss_ema * 0.99 + batch_loss * 0.01

        pred = logits.argmax(dim=1)
        correct += pred.eq(labels).sum().item()
        total += pred.shape[0]

        # Show the smoothed loss; the raw per-batch value is too noisy to read.
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {correct / total:.2%}")

    # NOTE(review): `scheduler` (created before this loop) is never stepped,
    # so the learning rate stays constant. Left as is, because stepping it
    # would change the trained model and with it the published result table
    # -- confirm the intent.

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, labels in loader_in_test:
            # Use the configured device instead of a hard-coded `.cuda()`.
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.argmax(dim=1)
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

        print(f"Test Accuracy: {correct / total:.2%}")
# Datasets that serve as out-of-distribution data.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

# Maps the dataset class name to a loader over ID test data + that OOD dataset.
datasets = {}

for ood_dataset in ood_datasets:
    # Use the shared `root` instead of repeating the literal path, so that all
    # datasets follow a single setting. `ToUnknown` rewrites the targets of
    # the OOD samples (presumably to the "unknown" label that `OODMetrics`
    # expects -- see the pytorch_ood documentation).
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader

Fit detectors to training data (some require this, some do not)

print("STAGE 2: Creating OOD Detectors")

# Insertion order is kept, so detectors are fitted and evaluated in this order.
# ViM and Mahalanobis operate on `model.features`; the others wrap the full
# classifier.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis": Mahalanobis(model.features, eps=0.0),
    "KLMatching": KLMatching(model),
    "MaxSoftmax": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
}


print(f"> Fitting {len(detectors)} detectors")

# `fit` is called on every detector, whether or not it needs the training data.
for name in detectors:
    detector = detectors[name]
    print(f"--> Fitting {name}")
    detector.fit(loader_in_train, device=device)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# One row (dict) per detector/dataset combination.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")

            # Fresh metric accumulator for every detector/dataset pair.
            metrics = OODMetrics()
            for inputs, targets in loader:
                scores = detector(inputs.to(device))
                metrics.update(scores, targets.to(device))

            row = {"Detector": detector_name, "Dataset": dataset_name}
            row.update(metrics.compute())
            results.append(row)
217
218
# calculate mean scores over all datasets, use percent

df = pd.DataFrame(results)
# "Dataset" is a string column. Since pandas 2.0, `mean()` no longer drops
# non-numeric columns silently but raises a TypeError, so restrict the
# aggregation to the numeric metric columns explicitly.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery