Newsgroups Outlier Exposure

Benchmark code for the 20 Newsgroups dataset, trained with Outlier Exposure using WikiText2 as the auxiliary outlier dataset. We test the detectors against three different text datasets and report the mean performance.

Uses the GRU model from the OOD detection baseline paper.

The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.

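| Detector    | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|-------------|-------|---------|----------|----------|
| ViM         | 57.87 | 67.14   | 65.42    | 53.18    |
| Mahalanobis | 63.27 | 68.39   | 69.40    | 50.52    |
| KLMatching  | 92.92 | 91.70   | 93.17    | 21.32    |
| MaxSoftmax  | 93.55 | 92.27   | 94.50    | 20.17    |
| Entropy     | 94.33 | 93.05   | 95.10    | 19.76    |
| EnergyBased | 94.63 | 93.14   | 95.67    | 16.90    |
| MaxLogit    | 94.66 | 93.14   | 95.66    | 17.18    |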

33 import pandas as pd
34 import torch
35 from torch.utils.data import DataLoader
36 from torchtext.data.utils import get_tokenizer
37 from torchtext.vocab import build_vocab_from_iterator
38 from tqdm import tqdm
39
40 from pytorch_ood.dataset.txt import Multi30k, NewsGroup20, Reuters52, WikiText2, WMT16Sentences
41 from pytorch_ood.detector import (
42     ODIN,
43     EnergyBased,
44     Entropy,
45     KLMatching,
46     Mahalanobis,
47     MaxLogit,
48     MaxSoftmax,
49     ViM,
50 )
51 from pytorch_ood.loss import OutlierExposureLoss
52 from pytorch_ood.model import GRUClassifier
53 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
54
fix_random_seed(123)

# Number of training epochs (also the period of the cosine LR schedule).
n_epochs = 5
# Learning rate for the Adam optimizer.
lr = 0.001
# Fall back to the CPU when no CUDA device is available so the script still
# runs (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Directory into which all datasets are downloaded.
root = "data"

Download the datasets and build the vocabulary

# The vocabulary is built from the raw training texts, so this first instance
# of the training split is created without a ``transform``.
tokenizer = get_tokenizer("basic_english")
train_dataset_in = NewsGroup20(root, download=True, train=True)
67
68
def yield_tokens(data_iter):
    """Lazily tokenize every text contained in ``data_iter``.

    Each element of ``data_iter`` is unpacked as a ``(text, label)`` pair. The
    label is discarded, and the token list produced by the module-level
    ``tokenizer`` is yielded.
    """
    yield from (tokenizer(sample) for sample, _label in data_iter)
72
73
# Reserve index 0 for padding (``collate_batch`` pads with zeros) and index 1
# for out-of-vocabulary tokens. Without explicit specials, index 0 is the most
# frequent training token. Padding and every unknown word would then be
# indistinguishable from that real word.
vocab = build_vocab_from_iterator(
    yield_tokens(train_dataset_in), specials=["<pad>", "<unk>"]
)
vocab.set_default_index(vocab["<unk>"])
76
77
def prep(x):
    """Convert the raw text ``x`` into a 1-D ``int64`` tensor of vocabulary indices.

    The text is split with the module-level ``tokenizer``, and each resulting
    token is looked up in ``vocab``.
    """
    token_ids = list(map(vocab.__getitem__, tokenizer(x)))
    return torch.tensor(token_ids, dtype=torch.int64)
# Re-create the in-distribution splits, this time tokenizing on the fly via
# ``prep``. WikiText2 serves as the outlier-exposure set; ``ToUnknown`` marks
# its targets as unknown.
train_dataset_in, dataset_in_test = (
    NewsGroup20(root, train=is_train, transform=prep) for is_train in (True, False)
)
train_ood_dataset = WikiText2(
    root,
    split="train",
    transform=prep,
    target_transform=ToUnknown(),
    download=True,
)

Pad the sequences in each batch to a common length and create the data loaders.

 93 def collate_batch(batch):
 94     texts = [i[0] for i in batch]
 95     labels = torch.tensor([i[1] for i in batch], dtype=torch.int64)
 96     t_lengths = torch.tensor([len(t) for t in texts])
 97     max_t_length = torch.max(t_lengths)
 98
 99     padded = []
100     for text in texts:
101         t = torch.cat([torch.zeros(max_t_length - len(text), dtype=torch.long), text])
102         padded.append(t)
103     return torch.stack(padded, dim=0), labels
104
105
# Training batches mix the in-distribution samples with the WikiText2 outliers.
# Both loaders pad their batches with ``collate_batch``.
loader_train = DataLoader(
    dataset=train_dataset_in + train_ood_dataset,
    batch_size=20,
    shuffle=True,
    collate_fn=collate_batch,
)
loader_in_test = DataLoader(
    dataset=dataset_in_test,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_batch,
)
print("STAGE 1: Train Model")
model = GRUClassifier(num_classes=20, n_vocab=len(vocab))

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Stepped once per epoch below, so T_max covers the whole training run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
criterion = OutlierExposureLoss()

model.to(device)

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    model.train()
    loss_ema = None
    correct = 0
    total = 0

    bar = tqdm(loader_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Exponential moving average of the loss, for a smoother progress bar.
        loss_ema = loss.item() if loss_ema is None else loss_ema * 0.99 + loss.item() * 0.01

        # Accuracy is only tracked on samples whose label is known; the
        # outlier-exposure samples are excluded via ``is_known``.
        known = is_known(labels)
        pred = logits.max(dim=1).indices
        correct += pred[known].eq(labels[known]).sum().item()
        total += known.sum().item()

        acc = correct / total if total else 0.0
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {acc:.2%}")

    # Advance the cosine learning-rate schedule once per epoch.
    scheduler.step()

    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for inputs, labels in loader_in_test:
            # Use the configured device instead of hard-coding CUDA.
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

        print(f"Test Accuracy: {correct / total:.2%}")
# Each test loader concatenates the in-distribution test split with one OOD
# dataset, whose targets are marked unknown via ``ToUnknown``.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

datasets = {}

for ood_dataset in ood_datasets:
    # Use the shared ``root`` rather than a hard-coded "data" directory, so all
    # datasets end up in the same configurable location.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader

Fit the detectors to the training data (some detectors require this, others do not).

print("STAGE 2: Creating OOD Detectors")

detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis": Mahalanobis(model.features, eps=0.0),
    "KLMatching": KLMatching(model),
    "MaxSoftmax": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
}

# Fit only on the in-distribution training split. ``loader_train`` also contains
# the WikiText2 outlier-exposure samples, which would otherwise enter whatever
# statistics a detector (e.g. Mahalanobis, ViM) estimates in ``fit``. This
# holds unless the library filters out unknown labels itself; we do not
# rely on that here.
loader_fit = DataLoader(train_dataset_in, batch_size=20, collate_fn=collate_batch)

print(f"> Fitting {len(detectors)} detectors")

for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.fit(loader_fit, device=device)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for x, y in loader:
                metrics.update(detector(x.to(device)), y.to(device))

            results.append(
                {"Detector": detector_name, "Dataset": dataset_name, **metrics.compute()}
            )

df = pd.DataFrame(results)
# Average every metric over the OOD datasets. ``numeric_only=True`` skips the
# string "Dataset" column, which pandas >= 2.0 refuses to average (TypeError).
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

Gallery generated by Sphinx-Gallery