Newsgroups Outlier Exposure

Benchmark code for 20 Newsgroups, trained with Outlier Exposure on the WikiText2 dataset. We test the detectors against three different text datasets and report the mean performance.

Uses the GRU model from the OOD detection baseline paper.

The original results cannot be reproduced because the original dictionaries (word-to-token mappings) are not available.

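| Detector    | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|-------------|-------|---------|----------|----------|
| ViM         | 57.87 | 67.14   | 65.42    | 53.18    |
| Mahalanobis | 63.27 | 68.39   | 69.40    | 50.52    |
| KLMatching  | 92.92 | 91.70   | 93.17    | 21.32    |
| MaxSoftmax  | 93.55 | 92.27   | 94.50    | 20.17    |
| Entropy     | 94.33 | 93.05   | 95.10    | 19.76    |
| EnergyBased | 94.63 | 93.14   | 95.67    | 16.90    |
| MaxLogit    | 94.66 | 93.14   | 95.66    | 17.18    |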

34 import pandas as pd
35 import torch
36 from torch.utils.data import DataLoader
37 from torchtext.data.utils import get_tokenizer
38 from torchtext.vocab import build_vocab_from_iterator
39 from tqdm import tqdm
40
41 from pytorch_ood.dataset.txt import (
42     Multi30k,
43     NewsGroup20,
44     Reuters52,
45     WikiText2,
46     WMT16Sentences,
47 )
48 from pytorch_ood.detector import (
49     ODIN,
50     EnergyBased,
51     Entropy,
52     KLMatching,
53     Mahalanobis,
54     MaxLogit,
55     MaxSoftmax,
56     ViM,
57 )
58 from pytorch_ood.loss import OutlierExposureLoss
59 from pytorch_ood.model import GRUClassifier
60 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
61
fix_random_seed(123)

# Training / data configuration.
n_epochs = 5
lr = 0.001
# Fall back to the CPU so the example still runs on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
root = "data"

Download the datasets and build the vocabulary from the in-distribution training split.

# Tokenizer shared by vocabulary construction and the ``prep`` transform.
tokenizer = get_tokenizer("basic_english")
# Raw (untransformed) ID training split; downloaded here and used to build the vocabulary.
train_dataset_in = NewsGroup20(root, train=True, download=True)
74
75
def yield_tokens(data_iter):
    """Lazily yield the token list of every ``(text, label)`` sample in ``data_iter``.

    The label is ignored; only the text is passed through ``tokenizer``.
    """
    yield from (tokenizer(text) for text, _label in data_iter)
79
80
# Reserve dedicated special tokens:
#   index 0 -> "<pad>": collate_batch pads sequences with index 0, so padding no longer
#                       collides with the most frequent real word.
#   index 1 -> "<unk>": out-of-vocabulary words map here instead of onto a real word.
vocab = build_vocab_from_iterator(
    yield_tokens(train_dataset_in), specials=["<pad>", "<unk>"]
)
vocab.set_default_index(vocab["<unk>"])
83
84
def prep(x):
    """Tokenize the raw text ``x`` and return its vocabulary indices as an int64 tensor."""
    tokens = tokenizer(x)
    indices = [vocab[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.int64)


# Re-create the ID splits with ``prep`` so every sample comes out as an index tensor.
train_dataset_in = NewsGroup20(root, train=True, transform=prep)
dataset_in_test = NewsGroup20(root, train=False, transform=prep)
# WikiText2 supplies the auxiliary outliers for Outlier Exposure; all of its targets
# are relabelled via ToUnknown.
train_ood_dataset = WikiText2(
    root,
    split="train",
    download=True,
    transform=prep,
    target_transform=ToUnknown(),
)

Define a collate function that left-pads each batch with zeros, then create the data loaders.

def collate_batch(batch):
    """Left-pad a batch of token-index sequences with zeros and stack them.

    Args:
        batch: sequence of ``(token_tensor, label)`` pairs.

    Returns:
        A ``(padded, labels)`` tuple. ``padded`` has one row per sample, each
        left-padded with zeros to the length of the longest sequence in the
        batch. ``labels`` is an int64 tensor.
    """
    sequences = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    width = torch.tensor([len(seq) for seq in sequences]).max()

    # Zeros go in front of the tokens so every sequence ends at the last column.
    rows = [
        torch.cat([torch.zeros(width - len(seq), dtype=torch.long), seq])
        for seq in sequences
    ]
    return torch.stack(rows, dim=0), targets
111
112
# Training batches mix ID samples (NewsGroup20) with outlier-exposure samples
# (WikiText2, relabelled via ToUnknown).
loader_train = DataLoader(
    dataset=train_dataset_in + train_ood_dataset,
    collate_fn=collate_batch,
    shuffle=True,
    batch_size=20,
)
# ID-only test split, used to report classification accuracy after each epoch.
loader_in_test = DataLoader(
    dataset=dataset_in_test,
    collate_fn=collate_batch,
    shuffle=True,
    batch_size=16,
)
print("STAGE 1: Train Model")
model = GRUClassifier(num_classes=20, n_vocab=len(vocab))

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Cosine annealing over the whole run; stepped once at the end of every epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
criterion = OutlierExposureLoss()

model.to(device)

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    model.train()
    loss_ema = None
    correct = 0
    total = 0

    bar = tqdm(loader_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Exponential moving average of the loss, for a smoother progress display.
        if loss_ema is None:
            loss_ema = loss.item()
        else:
            loss_ema = loss_ema * 0.99 + loss.item() * 0.01

        # Training accuracy is measured on samples that is_known keeps, which excludes
        # the outlier-exposure samples relabelled by ToUnknown.
        known = is_known(labels)
        pred = logits.max(dim=1).indices
        correct += pred[known].eq(labels[known]).sum().item()
        total += known.sum().item()

        # A batch may contain only outlier-exposure samples; avoid dividing by zero.
        acc = correct / total if total else 0.0
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {acc:.2%}")

    scheduler.step()

    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for inputs, labels in loader_in_test:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

        print(f"Test Accuracy: {correct / total:.2%}")
# OOD test sets. Each one is concatenated with the ID test split into its own loader,
# keyed by the dataset class name.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

datasets = {}

for ood_dataset in ood_datasets:
    # Use the configured data root instead of a hard-coded path.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader

Fit the detectors to the training data (some detectors require fitting, others do not).

print("STAGE 2: Creating OOD Detectors")

detectors = {}
detectors["Entropy"] = Entropy(model)
detectors["ViM"] = ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias)
detectors["Mahalanobis"] = Mahalanobis(model.features, eps=0.0)
detectors["KLMatching"] = KLMatching(model)
detectors["MaxSoftmax"] = MaxSoftmax(model)
detectors["EnergyBased"] = EnergyBased(model)
detectors["MaxLogit"] = MaxLogit(model)


print(f"> Fitting {len(detectors)} detectors")

# Fit on the in-distribution training split only. ``loader_train`` also contains the
# WikiText2 outlier-exposure samples, which would otherwise leak into what the
# feature-based detectors (ViM, Mahalanobis) learn during fitting. Shuffling is not
# needed for fitting.
loader_fit = DataLoader(train_dataset_in, batch_size=20, collate_fn=collate_batch)

for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.fit(loader_fit, device=device)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for x, y in loader:
                metrics.update(detector(x.to(device)), y.to(device))

            r = {"Detector": detector_name, "Dataset": dataset_name}

            r.update(metrics.compute())
            results.append(r)

df = pd.DataFrame(results)
# Average each detector's scores over the OOD datasets. "Dataset" is a string column;
# since pandas 2.0 ``mean`` raises on it unless non-numeric columns are excluded.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery