Newsgroups Outlier Exposure

Benchmark code for Newsgroups 20, trained with Outlier Exposure on the WikiText2 dataset. We test the detectors against three different text datasets and calculate the mean performance.

Uses GRU model from the OOD detection baseline paper.

The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.

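| Detector    | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|-------------|-------|---------|----------|----------|
| ViM         | 57.87 | 67.14   | 65.42    | 53.18    |
| Mahalanobis | 63.27 | 68.39   | 69.40    | 50.52    |
| KLMatching  | 92.92 | 91.70   | 93.17    | 21.32    |
| MaxSoftmax  | 93.55 | 92.27   | 94.50    | 20.17    |
| Entropy     | 94.33 | 93.05   | 95.10    | 19.76    |
| EnergyBased | 94.63 | 93.14   | 95.67    | 16.90    |
| MaxLogit    | 94.66 | 93.14   | 95.66    | 17.18    |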

34 import pandas as pd
35 import torch
36 from torch.utils.data import DataLoader
37 from torchtext.data.utils import get_tokenizer
38 from torchtext.vocab import build_vocab_from_iterator
39 from tqdm import tqdm
40
41 from pytorch_ood.dataset.txt import (
42     Multi30k,
43     NewsGroup20,
44     Reuters52,
45     WikiText2,
46     WMT16Sentences,
47 )
48 from pytorch_ood.detector import (
49     EnergyBased,
50     Entropy,
51     KLMatching,
52     Mahalanobis,
53     MaxLogit,
54     MaxSoftmax,
55     ViM,
56 )
57 from pytorch_ood.loss import OutlierExposureLoss
58 from pytorch_ood.model import GRUClassifier
59 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
60
# Fix the random seed so that repeated runs are comparable.
fix_random_seed(123)

# Number of passes over the training data and the Adam learning rate.
n_epochs = 5
lr = 0.001
# Prefer the first GPU, but fall back to the CPU so that the script also
# starts on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Directory that is handed to the dataset classes as their root folder.
root = "data"

Download datasets

# Raw (untransformed) training split; at this point it is only used to
# build the vocabulary.
train_dataset_in = NewsGroup20(root, train=True, download=True)

tokenizer = get_tokenizer("basic_english")


def yield_tokens(data_iter):
    """Yield the token list of every ``(text, label)`` sample in ``data_iter``."""
    for text, _ in data_iter:
        yield tokenizer(text)


# Reserve dedicated indices for padding and for out-of-vocabulary words.
# Specials are placed first, so "<pad>" receives index 0 -- the value that is
# used for padding when batches are assembled -- and unknown words map to
# "<unk>" instead of silently aliasing a real token.
vocab = build_vocab_from_iterator(
    yield_tokens(train_dataset_in),
    specials=["<pad>", "<unk>"],
    special_first=True,
)
vocab.set_default_index(vocab["<unk>"])


def prep(x):
    """Tokenize the text ``x`` and look up each token in the vocabulary.

    Returns a 1-D ``torch.int64`` tensor holding one vocabulary index per token.
    """
    token_ids = [vocab[token] for token in tokenizer(x)]
    return torch.tensor(token_ids, dtype=torch.int64)
# In-distribution data: both splits are converted to index tensors by ``prep``.
train_dataset_in = NewsGroup20(root, train=True, transform=prep)
dataset_in_test = NewsGroup20(root, train=False, transform=prep)

# Auxiliary outlier data for Outlier Exposure; ``ToUnknown`` rewrites its targets.
train_ood_dataset = WikiText2(
    root,
    split="train",
    download=True,
    transform=prep,
    target_transform=ToUnknown(),
)

Add padding, etc.

 99 def collate_batch(batch):
100     texts = [i[0] for i in batch]
101     labels = torch.tensor([i[1] for i in batch], dtype=torch.int64)
102     t_lengths = torch.tensor([len(t) for t in texts])
103     max_t_length = torch.max(t_lengths)
104
105     padded = []
106     for text in texts:
107         t = torch.cat([torch.zeros(max_t_length - len(text), dtype=torch.long), text])
108         padded.append(t)
109     return torch.stack(padded, dim=0), labels
110
111
# Training batches mix in-distribution samples with the auxiliary outliers.
loader_train = DataLoader(
    train_dataset_in + train_ood_dataset,
    collate_fn=collate_batch,
    batch_size=20,
    shuffle=True,
)
# In-distribution test split, used to report the classification accuracy.
loader_in_test = DataLoader(
    dataset_in_test,
    collate_fn=collate_batch,
    batch_size=16,
    shuffle=True,
)
print("STAGE 1: Train Model")
model = GRUClassifier(num_classes=20, n_vocab=len(vocab))

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Cosine decay of the learning rate over the whole run; stepped once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
criterion = OutlierExposureLoss()

model.to(device)

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    model.train()
    loss_ema = None
    correct = 0
    total = 0

    bar = tqdm(loader_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Exponential moving average of the loss. The explicit ``None`` check
        # keeps an average of exactly 0.0 from being re-initialised.
        batch_loss = loss.item()
        loss_ema = batch_loss if loss_ema is None else loss_ema * 0.99 + batch_loss * 0.01

        # Accuracy is tracked on the known (in-distribution) samples only.
        known = is_known(labels)
        pred = logits.max(dim=1).indices
        correct += pred[known].eq(labels[known]).sum().item()
        total += known.sum().item()

        # A batch may consist of outliers only, so guard the division.
        accuracy = correct / total if total else 0.0
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {accuracy:.2%}")

    # Advance the cosine schedule; without this call the learning rate
    # would stay constant for the whole run.
    scheduler.step()

    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for inputs, labels in loader_in_test:
            # Use the configured device rather than the default CUDA device.
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

        print(f"Test Accuracy: {correct / total:.2%}")
# Text datasets that serve as out-of-distribution test data.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

# Maps the dataset class name to a loader over the in-distribution test split
# combined with that OOD dataset.
datasets = {}

for ood_dataset in ood_datasets:
    # Use the shared ``root`` setting instead of repeating the literal path,
    # so that all datasets end up in the same directory when it is changed.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader

Fit detectors to training data (some require this, some do not)

print("STAGE 2: Creating OOD Detectors")

# Detectors keyed by display name; the insertion order is also the order in
# which they are fitted and evaluated.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis": Mahalanobis(model.features),
    "KLMatching": KLMatching(model),
    "MaxSoftmax": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
}


print(f"> Fitting {len(detectors)} detectors")

# Every detector is fitted on the training loader.
for name in detectors:
    print(f"--> Fitting {name}")
    detector = detectors[name]
    detector.to(device)
    detector.fit(loader_train)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# One row per (detector, dataset) pair.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for x, y in loader:
                scores = detector(x.to(device))
                metrics.update(scores, y.to(device))

            # Start the row with its identifying columns, then add the metrics.
            r = dict(Detector=detector_name, Dataset=dataset_name)
            r.update(metrics.compute())
            results.append(r)
df = pd.DataFrame(results)
# Average every metric over the OOD datasets and express it in percent.
# ``numeric_only=True`` skips the string-valued "Dataset" column; recent pandas
# versions raise a TypeError for it instead of dropping it silently.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery