Note
Go to the end to download the full example code.
Newsgroups Outlier Exposure
Benchmark code for Newsgroups 20, trained with Outlier Exposure on the WikiText2 dataset. We test the detectors against three different text datasets and report the mean performance.
Uses GRU model from the OOD detection baseline paper.
The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.
| Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|
| ViM | 57.87 | 67.14 | 65.42 | 53.18 |
| Mahalanobis | 63.27 | 68.39 | 69.40 | 50.52 |
| KLMatching | 92.92 | 91.70 | 93.17 | 21.32 |
| MaxSoftmax | 93.55 | 92.27 | 94.50 | 20.17 |
| Entropy | 94.33 | 93.05 | 95.10 | 19.76 |
| EnergyBased | 94.63 | 93.14 | 95.67 | 16.90 |
| MaxLogit | 94.66 | 93.14 | 95.66 | 17.18 |
34 import pandas as pd
35 import torch
36 from torch.utils.data import DataLoader
37 from torchtext.data.utils import get_tokenizer
38 from torchtext.vocab import build_vocab_from_iterator
39 from tqdm import tqdm
40
41 from pytorch_ood.dataset.txt import (
42 Multi30k,
43 NewsGroup20,
44 Reuters52,
45 WikiText2,
46 WMT16Sentences,
47 )
48 from pytorch_ood.detector import (
49 EnergyBased,
50 Entropy,
51 KLMatching,
52 Mahalanobis,
53 MaxLogit,
54 MaxSoftmax,
55 ViM,
56 )
57 from pytorch_ood.loss import OutlierExposureLoss
58 from pytorch_ood.model import GRUClassifier
59 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
60
# Fix the random seed (pytorch_ood.utils helper) so repeated runs are comparable.
fix_random_seed(123)

n_epochs = 5
lr = 0.001
# Prefer the first GPU, but fall back to the CPU so the script still runs
# (slowly) on machines without CUDA instead of failing at ``model.to(device)``.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
root = "data"
download datasets
# Load the training split without any transform; this raw copy is only used to
# build the vocabulary below.
train_dataset_in = NewsGroup20(root, train=True, download=True)

tokenizer = get_tokenizer("basic_english")
73
74
def yield_tokens(data_iter):
    """Yield the tokenized text of every sample in ``data_iter``.

    ``data_iter`` is iterated as ``(text, label)`` pairs; the label is ignored.
    Tokenization is done with the module-level ``tokenizer``.
    """
    for text, _ in data_iter:
        yield tokenizer(text)
78
79
vocab = build_vocab_from_iterator(yield_tokens(train_dataset_in))
# Tokens that are not in the vocabulary are looked up as index 0.
# NOTE(review): no special tokens are passed above, so index 0 is presumably
# also a regular token, and ``collate_batch`` pads with 0 as well -- confirm
# that sharing one index for padding / unknown / a real token is intended.
vocab.set_default_index(0)
82
83
def prep(x):
    """Tokenize the string ``x`` and return its vocabulary indices as an int64 tensor.

    Uses the module-level ``tokenizer`` and ``vocab``.
    """
    token_ids = [vocab[token] for token in tokenizer(x)]
    return torch.tensor(token_ids, dtype=torch.int64)
# Re-create the in-distribution splits, this time mapping every text to token
# indices via ``prep``.
train_dataset_in = NewsGroup20(root, train=True, transform=prep)
dataset_in_test = NewsGroup20(root, train=False, transform=prep)
# Auxiliary outlier data for Outlier Exposure; ``ToUnknown`` replaces the
# targets so these samples are treated as unknown during training.
train_ood_dataset = WikiText2(
    root, split="train", download=True, transform=prep, target_transform=ToUnknown()
)
Add padding, etc.
def collate_batch(batch):
    """Collate ``(token_ids, label)`` samples into one left-padded batch.

    Each sequence is padded at the front with zeros up to the length of the
    longest sequence in the batch.

    Returns:
        A ``(padded, labels)`` tuple. ``padded`` stacks the padded sequences
        along dim 0; ``labels`` is an int64 tensor with one entry per sample.
    """
    texts = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    max_t_length = max(len(text) for text in texts)

    # Padding goes in front of the tokens, so the real text ends every sequence.
    padded = [
        torch.cat([torch.zeros(max_t_length - len(text), dtype=torch.long), text])
        for text in texts
    ]
    return torch.stack(padded, dim=0), labels
110
111
# Training batches mix in-distribution and outlier samples: ``+`` concatenates
# the two datasets and ``shuffle=True`` interleaves them.
loader_train = DataLoader(
    train_dataset_in + train_ood_dataset,
    batch_size=20,
    shuffle=True,
    collate_fn=collate_batch,
)
loader_in_test = DataLoader(dataset_in_test, batch_size=16, shuffle=True, collate_fn=collate_batch)
print("STAGE 1: Train Model")
model = GRUClassifier(num_classes=20, n_vocab=len(vocab))

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Cosine learning-rate schedule spanning the whole training run (one period
# over ``n_epochs`` scheduler steps).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
criterion = OutlierExposureLoss()

model.to(device)
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    model.train()
    loss_ema = None
    correct = 0
    total = 0

    bar = tqdm(loader_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Exponential moving average of the loss. Compare against None so that
        # an average of exactly 0.0 is not mistaken for "not initialised yet".
        loss_value = loss.item()
        loss_ema = loss_value if loss_ema is None else loss_ema * 0.99 + loss_value * 0.01

        # Training accuracy only counts samples selected by ``is_known`` (the
        # outlier samples have no class to predict). Keep the counters as
        # plain Python numbers rather than device tensors.
        known = is_known(labels)
        pred = logits.max(dim=1).indices
        correct += pred[known].eq(labels[known]).sum().item()
        total += known.sum().item()

        # The first batches may contain outliers only; avoid dividing by zero.
        acc = correct / total if total else 0.0
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {acc:.2%}")

    # Advance the cosine schedule once per epoch (it was created with
    # T_max=n_epochs); without this call the learning rate never changes.
    scheduler.step()

    # Evaluate on the in-distribution test split after every epoch.
    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for inputs, labels in loader_in_test:
            # Use the configured device rather than a hard-coded ``.cuda()``.
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

        print(f"Test Accuracy: {correct / total:.2%}")
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

# Maps the OOD dataset's class name to a loader over the in-distribution test
# split concatenated with that OOD dataset.
datasets = {}

for ood_dataset in ood_datasets:
    # Use the shared ``root`` so all datasets live in the same directory.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader
Fit detectors to training data (some require this, some do not)
print("STAGE 2: Creating OOD Detectors")

# Detector name -> detector instance. ViM and Mahalanobis operate on the
# feature extractor (``model.features``); the others wrap the full classifier.
detectors = {}
detectors["Entropy"] = Entropy(model)
detectors["ViM"] = ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias)
detectors["Mahalanobis"] = Mahalanobis(model.features)
detectors["KLMatching"] = KLMatching(model)
detectors["MaxSoftmax"] = MaxSoftmax(model)
detectors["EnergyBased"] = EnergyBased(model)
detectors["MaxLogit"] = MaxLogit(model)


print(f"> Fitting {len(detectors)} detectors")

for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.to(device)
    # NOTE(review): ``loader_train`` also contains the outlier-exposure
    # samples; confirm that every detector's ``fit`` ignores unknown targets.
    detector.fit(loader_train)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# One row per (detector, dataset) pair: the two names plus the computed metrics.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            # Fresh metric accumulator for every detector/dataset combination.
            metrics = OODMetrics()
            for x, y in loader:
                metrics.update(detector(x.to(device)), y.to(device))

            results.append(
                {"Detector": detector_name, "Dataset": dataset_name, **metrics.compute()}
            )
df = pd.DataFrame(results)
# Average each metric over the OOD datasets and express it in percent.
# ``numeric_only=True`` skips the string "Dataset" column; without it,
# pandas >= 2.0 raises a TypeError when trying to average strings.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))