Note
Go to the end to download the full example code.
Newsgroups Outlier Exposure
Benchmark code for Newsgroups 20, trained with Outlier Exposure on the WikiText2 dataset. We test the detectors against three different text datasets and calculate the mean performance.
Uses GRU model from the OOD detection baseline paper.
The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.
| Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|
| ViM | 57.87 | 67.14 | 65.42 | 53.18 |
| Mahalanobis | 63.27 | 68.39 | 69.40 | 50.52 |
| KLMatching | 92.92 | 91.70 | 93.17 | 21.32 |
| MaxSoftmax | 93.55 | 92.27 | 94.50 | 20.17 |
| Entropy | 94.33 | 93.05 | 95.10 | 19.76 |
| EnergyBased | 94.63 | 93.14 | 95.67 | 16.90 |
| MaxLogit | 94.66 | 93.14 | 95.66 | 17.18 |
33 import pandas as pd
34 import torch
35 from torch.utils.data import DataLoader
36 from torchtext.data.utils import get_tokenizer
37 from torchtext.vocab import build_vocab_from_iterator
38 from tqdm import tqdm
39
40 from pytorch_ood.dataset.txt import Multi30k, NewsGroup20, Reuters52, WikiText2, WMT16Sentences
41 from pytorch_ood.detector import (
42 ODIN,
43 EnergyBased,
44 Entropy,
45 KLMatching,
46 Mahalanobis,
47 MaxLogit,
48 MaxSoftmax,
49 ViM,
50 )
51 from pytorch_ood.loss import OutlierExposureLoss
52 from pytorch_ood.model import GRUClassifier
53 from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known
54
# Fix the random seed so that repeated runs are comparable.
fix_random_seed(123)

n_epochs = 5  # number of passes over the training data
lr = 0.001  # learning rate handed to the optimizer
# Prefer the first GPU, but fall back to the CPU so that the script also
# runs on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
root = "data"  # directory passed to the dataset constructors
download datasets
# Tokenizer shared by the vocabulary construction and the `prep` transform.
tokenizer = get_tokenizer("basic_english")

# Raw (untransformed) training split, requested with download=True; it is
# iterated once below to build the vocabulary.
train_dataset_in = NewsGroup20(root, download=True, train=True)
67
68
def yield_tokens(data_iter):
    """Lazily yield the tokenized text of every ``(text, label)`` pair.

    :param data_iter: iterable of 2-tuples whose first entry is the raw text;
        the second entry is ignored.
    :return: generator producing one ``tokenizer(text)`` result per sample.
    """
    yield from (tokenizer(text) for text, _ in data_iter)
72
73
# Build the token-to-index mapping from the tokenized training texts.
token_stream = yield_tokens(train_dataset_in)
vocab = build_vocab_from_iterator(token_stream)
# Tokens that are not in the vocabulary are looked up as index 0.
# NOTE(review): no special tokens are registered here, so index 0 presumably
# also belongs to a regular token (and is used as padding value in
# `collate_batch`) - confirm that this is intended.
vocab.set_default_index(0)
76
77
def prep(x):
    """Convert a raw text into a 1-D ``int64`` tensor of vocabulary indices.

    :param x: text to tokenize with the module-level ``tokenizer``.
    :return: tensor holding ``vocab[token]`` for every token, in order.
    """
    tokens = tokenizer(x)
    indices = [vocab[token] for token in tokens]
    return torch.tensor(indices, dtype=torch.int64)
# Re-create the training split, this time with the `prep` transform applied.
train_dataset_in = NewsGroup20(root, train=True, transform=prep)
# The test split was the only dataset never requested with download=True;
# ask for it explicitly so a fresh data directory works as well.
dataset_in_test = NewsGroup20(root, train=False, download=True, transform=prep)
# Auxiliary outlier data for Outlier Exposure; targets are replaced via
# ToUnknown() (presumably marking every sample as unknown/OOD).
train_ood_dataset = WikiText2(
    root, split="train", download=True, transform=prep, target_transform=ToUnknown()
)
Add padding, etc.
def collate_batch(batch):
    """Left-pad the token sequences of a batch with zeros and stack them.

    :param batch: sequence of ``(tokens, label)`` pairs, where ``tokens`` is a
        1-D integer tensor and ``label`` an integer.
    :return: tuple ``(inputs, labels)``; ``inputs`` has one row per sample,
        each padded at the front to the length of the longest sequence, and
        ``labels`` is an ``int64`` tensor.
    """
    sequences = [sample[0] for sample in batch]
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)

    # Length of the longest sequence in the batch (0-dim tensor).
    longest = torch.tensor([len(seq) for seq in sequences]).max()

    rows = [
        torch.cat([torch.zeros(longest - len(seq), dtype=torch.long), seq])
        for seq in sequences
    ]
    return torch.stack(rows, dim=0), targets
104
105
# Training data: in-distribution samples concatenated with the outlier set,
# shuffled so that both kinds are mixed within the batches.
train_dataset_mixed = train_dataset_in + train_ood_dataset
loader_train = DataLoader(
    train_dataset_mixed,
    shuffle=True,
    batch_size=20,
    collate_fn=collate_batch,
)

# In-distribution test data, used for the per-epoch accuracy.
loader_in_test = DataLoader(
    dataset_in_test,
    shuffle=True,
    batch_size=16,
    collate_fn=collate_batch,
)
print("STAGE 1: Train Model")

# GRU classifier over the 20 newsgroup classes; the embedding size follows
# the vocabulary built above.
model = GRUClassifier(num_classes=20, n_vocab=len(vocab))

# Loss that takes the logits and labels of mixed in-distribution/outlier batches.
criterion = OutlierExposureLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Cosine learning-rate schedule spanning the whole training run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)

model.to(device)
# Train for `n_epochs` epochs; after every epoch, advance the learning-rate
# schedule and report the accuracy on the in-distribution test set.
for epoch in range(n_epochs):
    print(f"Epoch {epoch}")

    model.train()
    loss_ema = None  # exponential moving average of the training loss
    correct = 0
    total = 0

    bar = tqdm(loader_train)

    for inputs, labels in bar:
        inputs = inputs.to(device)
        labels = labels.to(device)
        logits = model(inputs)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compare against None explicitly: a loss of exactly 0.0 is a valid
        # average and must not reset the EMA.
        loss_value = loss.item()
        loss_ema = loss_value if loss_ema is None else loss_ema * 0.99 + loss_value * 0.01

        # Accuracy is only defined for samples with a known class label, so
        # outlier samples are masked out.
        known = is_known(labels)
        pred = logits.max(dim=1).indices
        correct += pred[known].eq(labels[known]).sum().item()
        total += known.sum().item()

        # Guard against batches seen so far containing outliers only.
        accuracy = correct / total if total else 0.0
        bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {accuracy:.2%}")

    # Advance the cosine schedule once per epoch (T_max=n_epochs); without
    # this call the learning rate never changes.
    scheduler.step()

    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for inputs, labels in loader_in_test:
            # Use the configured device instead of a hard-coded .cuda().
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            pred = logits.max(dim=1).indices
            correct += pred.eq(labels).sum().item()
            total += pred.shape[0]

    print(f"Test Accuracy: {correct / total:.2%}")
# OOD test sets; each one is combined with the in-distribution test set.
ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

# Maps the dataset class name to a loader over (in-distribution + OOD) test data.
datasets = {}

for ood_dataset in ood_datasets:
    # Use the shared `root` setting instead of repeating the literal path,
    # so all datasets end up in the same directory when `root` is changed.
    dataset_out_test = ood_dataset(
        root=root, transform=prep, target_transform=ToUnknown(), download=True
    )
    test_loader = DataLoader(
        dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
    )
    datasets[ood_dataset.__name__] = test_loader
Fit detectors to training data (some require this, some do not)
print("STAGE 2: Creating OOD Detectors")

# Detector name -> detector instance. ViM and Mahalanobis operate on the
# feature extractor (`model.features`), the others on the full model.
detectors = {
    "Entropy": Entropy(model),
    "ViM": ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias),
    "Mahalanobis": Mahalanobis(model.features, eps=0.0),
    "KLMatching": KLMatching(model),
    "MaxSoftmax": MaxSoftmax(model),
    "EnergyBased": EnergyBased(model),
    "MaxLogit": MaxLogit(model),
}


print(f"> Fitting {len(detectors)} detectors")

# Every detector is fitted on the training loader, whether it needs it or not.
for detector_name, fitted in detectors.items():
    print(f"--> Fitting {detector_name}")
    fitted.fit(loader_train, device=device)
print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
# One record per (detector, dataset) pair, holding the computed metrics.
results = []

with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        for dataset_name, loader in datasets.items():
            print(f"--> {dataset_name}")
            metrics = OODMetrics()
            for inputs, targets in loader:
                scores = detector(inputs.to(device))
                metrics.update(scores, targets.to(device))

            # Metric entries are merged after the two identifying columns.
            results.append(
                {"Detector": detector_name, "Dataset": dataset_name, **metrics.compute()}
            )
# Average every metric over the OOD datasets, per detector, in percent.
df = pd.DataFrame(results)
# numeric_only=True skips the string "Dataset" column; without it, recent
# pandas versions raise a TypeError when averaging.
mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
# Print as CSV, sorted by ascending AUROC, with two decimals.
print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))