Note
Go to the end to download the full example code.
Newsgroups Outlier Exposure
Benchmark code for Newsgroups 20, trained with Outlier Exposure on the WikiText2 dataset. We evaluate the detectors against three different text datasets and report the mean performance.
Uses the GRU model from the OOD detection baseline paper.
The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.
| Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|
| ViM | 57.87 | 67.14 | 65.42 | 53.18 |
| Mahalanobis | 63.27 | 68.39 | 69.40 | 50.52 |
| KLMatching | 92.92 | 91.70 | 93.17 | 21.32 |
| MaxSoftmax | 93.55 | 92.27 | 94.50 | 20.17 |
| Entropy | 94.33 | 93.05 | 95.10 | 19.76 |
| EnergyBased | 94.63 | 93.14 | 95.67 | 16.90 |
| MaxLogit | 94.66 | 93.14 | 95.66 | 17.18 |
34 import pandas as pd
35 import torch
36 from torch.utils.data import DataLoader
37 from torchtext.data.utils import get_tokenizer
38 from torchtext.vocab import build_vocab_from_iterator
39 from tqdm import tqdm
40
41 from pytorch_ood.dataset.txt import (
42 Multi30k,
43 NewsGroup20,
44 Reuters52,
45 WikiText2,
46 WMT16Sentences,
47 )
48 from pytorch_ood.detector import (
49 ODIN,
50 EnergyBased,
51 Entropy,
52 KLMatching,
53 Mahalanobis,
54 MaxLogit,
55 MaxSoftmax,
56 ViM,
57 )
58 from pytorch_ood.loss import OutlierExposureLoss
59 from pytorch_ood.model import GRUClassifier
60 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
61
# Fix all random seeds so that repeated runs of the benchmark are comparable.
fix_random_seed(123)

# Training hyper-parameters.
n_epochs = 5
lr = 0.001
# Prefer the first GPU, but fall back to the CPU so the script does not crash
# on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Root directory passed to the dataset classes.
root = "data"
Download the datasets and build the vocabulary from the NewsGroup20 training split.
# Raw (untransformed) NewsGroup20 training split. It is only used to build the
# vocabulary and is re-created with a transform further below.
train_dataset_in = NewsGroup20(root, train=True, download=True)

tokenizer = get_tokenizer("basic_english")


def yield_tokens(data_iter):
    """Lazily yield the token list for each text in ``data_iter``.

    Every item of ``data_iter`` is unpacked as a ``(text, label)`` pair; the
    label is ignored.
    """
    yield from (tokenizer(text) for text, _label in data_iter)


# Vocabulary built from the training texts; unknown tokens map to index 0.
# NOTE(review): collate_batch also pads with 0, and no special tokens are
# passed here, so index 0 presumably belongs to a real (frequent) word. Padding,
# OOV tokens and that word then share one id -- confirm this is intended.
vocab = build_vocab_from_iterator(yield_tokens(train_dataset_in))
vocab.set_default_index(0)
83
84
def prep(x):
    """Tokenize the string ``x`` and look up every token in ``vocab``.

    Returns a 1-D ``int64`` tensor of token ids.
    """
    token_ids = [vocab[token] for token in tokenizer(x)]
    return torch.tensor(token_ids, dtype=torch.int64)


# In-distribution train/test splits, already converted to token-id tensors.
train_dataset_in = NewsGroup20(root, train=True, transform=prep)
dataset_in_test = NewsGroup20(root, train=False, transform=prep)

# Auxiliary outliers for Outlier Exposure. ToUnknown() presumably replaces
# every label with the library's "unknown" marker, so is_known() is False for
# these samples -- confirm against pytorch_ood.utils.
train_ood_dataset = WikiText2(
    root,
    split="train",
    download=True,
    transform=prep,
    target_transform=ToUnknown(),
)
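Add padding: the collate function left-pads every token sequence in a batch with zeros to the length of the longest sequence.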
def collate_batch(batch):
    """Merge ``(token_ids, label)`` samples into one padded batch.

    Each 1-D ``token_ids`` tensor is left-padded with ``torch.long`` zeros up to
    the length of the longest sequence in the batch.

    Returns:
        ``(inputs, labels)``: ``inputs`` stacks the padded sequences along a new
        first dimension, ``labels`` is a 1-D ``int64`` tensor of the labels.
    """
    sequences = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    longest = max(len(seq) for seq in sequences)

    def _left_pad(seq):
        # Zeros go in front, so the real tokens always end at the last position.
        padding = torch.zeros(longest - len(seq), dtype=torch.long)
        return torch.cat([padding, seq])

    return torch.stack([_left_pad(seq) for seq in sequences], dim=0), labels
111
112
# Outlier Exposure trains on in-distribution and auxiliary outlier samples
# mixed together in the same shuffled batches.
loader_train = DataLoader(
    dataset=train_dataset_in + train_ood_dataset,
    batch_size=20,
    shuffle=True,
    collate_fn=collate_batch,
)
# In-distribution test split, used to report classification accuracy.
loader_in_test = DataLoader(
    dataset=dataset_in_test,
    batch_size=16,
    shuffle=True,
    collate_fn=collate_batch,
)
print("STAGE 1: Train Model")
# NewsGroup20 has 20 in-distribution classes.
model = GRUClassifier(num_classes=20, n_vocab=len(vocab))
# Move the model to the target device *before* building the optimizer, as the
# torch.optim documentation recommends, so the optimizer is constructed on the
# parameters that will actually be trained.
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Cosine decay over the whole run; T_max=n_epochs assumes one scheduler.step()
# per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
criterion = OutlierExposureLoss()
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    model.train()
    loss_ema = None
    correct = 0
    total = 0

    bar = tqdm(loader_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Exponential moving average of the loss for a smoother progress bar.
        # Compare with None explicitly: a loss of exactly 0.0 must not reset it.
        loss_value = loss.item()
        if loss_ema is None:
            loss_ema = loss_value
        else:
            loss_ema = loss_ema * 0.99 + loss_value * 0.01

        # Training accuracy only counts in-distribution samples; the outlier
        # samples (is_known() is False) have no class to predict.
        known = is_known(labels)
        pred = logits.max(dim=1).indices
        correct += pred[known].eq(labels[known]).sum().item()
        total += known.sum().item()

        # A batch can consist solely of outliers, so guard the division.
        acc = correct / total if total else 0.0
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {acc:.2%}")

    # Advance the cosine learning-rate schedule once per epoch; without this
    # the scheduler created for the optimizer has no effect.
    scheduler.step()

    with torch.no_grad():
        model.eval()
        correct = 0
        total = 0
        for inputs, labels in loader_in_test:
            # Use the configured device rather than a hard-coded .cuda().
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

        print(f"Test Accuracy: {correct / total:.2%}")
# OOD test sets: each loader serves the in-distribution test split followed by
# one outlier text dataset. Loaders are keyed by the OOD dataset's class name.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

datasets = {}

for ood_dataset in ood_datasets:
    # Use the configured root directory instead of a hard-coded "data", so all
    # datasets live in the same place when `root` is changed.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader
Fit the detectors to the training data (some detectors require this, others do not).
print("STAGE 2: Creating OOD Detectors")

# Insertion order determines the order of fitting and evaluation. ViM and
# Mahalanobis work on model.features; ViM additionally receives the weight and
# bias of model.fc.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis": Mahalanobis(model.features, eps=0.0),
    "KLMatching": KLMatching(model),
    "MaxSoftmax": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
}


print(f"> Fitting {len(detectors)} detectors")

for name, detector in detectors.items():
    print(f"--> Fitting {name}")
    detector.to(device)
    # NOTE(review): loader_train also yields the WikiText2 outliers; confirm
    # that each detector's fit() skips or tolerates samples whose labels are
    # "unknown".
    detector.fit(loader_train)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# One result row per (detector, dataset) pair.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for x, y in loader:
                # Outlier scores first, then the matching labels.
                scores = detector(x.to(device))
                metrics.update(scores, y.to(device))

            r = {"Detector": detector_name, "Dataset": dataset_name}
            # Merge the computed metrics into the identifying columns.
            r.update(metrics.compute())
            results.append(r)
# Average every metric over the OOD datasets per detector, in percent.
df = pd.DataFrame(results)
# numeric_only=True excludes the string "Dataset" column: since pandas 2.0,
# GroupBy.mean() raises a TypeError on object columns instead of dropping them.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))