
.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/text/newsgroups_oe.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_text_newsgroups_oe.py:


Newsgroups Outlier Exposure
==============================

Benchmark code for a classifier trained on 20 Newsgroups with Outlier Exposure,
using the WikiText2 dataset as the source of outliers.
We test the detectors against three different text datasets (Reuters52, Multi30k,
and WMT16Sentences) and calculate the mean performance.
Uses the GRU model from the OOD detection baseline paper.

The original results cannot be reproduced, as the dictionaries
(word-to-token mappings) are not available.

+-------------+-------+---------+----------+----------+
| Detector    | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
+=============+=======+=========+==========+==========+
| ViM         | 57.87 | 67.14   | 65.42    | 53.18    |
+-------------+-------+---------+----------+----------+
| Mahalanobis | 63.27 | 68.39   | 69.40    | 50.52    |
+-------------+-------+---------+----------+----------+
| KLMatching  | 92.92 | 91.70   | 93.17    | 21.32    |
+-------------+-------+---------+----------+----------+
| MaxSoftmax  | 93.55 | 92.27   | 94.50    | 20.17    |
+-------------+-------+---------+----------+----------+
| Entropy     | 94.33 | 93.05   | 95.10    | 19.76    |
+-------------+-------+---------+----------+----------+
| EnergyBased | 94.63 | 93.14   | 95.67    | 16.90    |
+-------------+-------+---------+----------+----------+
| MaxLogit    | 94.66 | 93.14   | 95.66    | 17.18    |
+-------------+-------+---------+----------+----------+

.. GENERATED FROM PYTHON SOURCE LINES 33-69

.. code-block:: Python
    :lineno-start: 34


    import pandas as pd
    import torch
    from torch.utils.data import DataLoader
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator
    from tqdm import tqdm

    from pytorch_ood.dataset.txt import (
        Multi30k,
        NewsGroup20,
        Reuters52,
        WikiText2,
        WMT16Sentences,
    )
    from pytorch_ood.detector import (
        ODIN,
        EnergyBased,
        Entropy,
        KLMatching,
        Mahalanobis,
        MaxLogit,
        MaxSoftmax,
        ViM,
    )
    from pytorch_ood.loss import OutlierExposureLoss
    from pytorch_ood.model import GRUClassifier
    from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known

    fix_random_seed(123)

    n_epochs = 5
    lr = 0.001
    device = "cuda:0"
    root = "data"


.. GENERATED FROM PYTHON SOURCE LINES 70-71

Download the datasets and build the vocabulary

.. GENERATED FROM PYTHON SOURCE LINES 71-89

.. code-block:: Python
    :lineno-start: 71


    train_dataset_in = NewsGroup20(root, train=True, download=True)

    tokenizer = get_tokenizer("basic_english")


    def yield_tokens(data_iter):
        for text, _ in data_iter:
            yield tokenizer(text)


    # the vocabulary is built from the in-distribution training texts only
    vocab = build_vocab_from_iterator(yield_tokens(train_dataset_in))
    # tokens missing from the vocabulary map to index 0
    vocab.set_default_index(0)


    def prep(x):
        return torch.tensor([vocab[v] for v in tokenizer(x)], dtype=torch.int64)


.. GENERATED FROM PYTHON SOURCE LINES 90-96

.. code-block:: Python
    :lineno-start: 90


    train_dataset_in = NewsGroup20(root, train=True, transform=prep)
    dataset_in_test = NewsGroup20(root, train=False, transform=prep)
    train_ood_dataset = WikiText2(
        root, split="train", download=True, transform=prep, target_transform=ToUnknown()
    )


.. GENERATED FROM PYTHON SOURCE LINES 97-98

Add padding, etc.

.. GENERATED FROM PYTHON SOURCE LINES 98-121

.. code-block:: Python
    :lineno-start: 100


    def collate_batch(batch):
        texts = [i[0] for i in batch]
        labels = torch.tensor([i[1] for i in batch], dtype=torch.int64)
        t_lengths = torch.tensor([len(t) for t in texts])
        max_t_length = torch.max(t_lengths)
        padded = []
        for text in texts:
            # left-pad each sequence with zeros up to the longest one in the batch
            t = torch.cat([torch.zeros(max_t_length - len(text), dtype=torch.long), text])
            padded.append(t)
        return torch.stack(padded, dim=0), labels


    loader_train = DataLoader(
        train_dataset_in + train_ood_dataset,
        batch_size=20,
        shuffle=True,
        collate_fn=collate_batch,
    )
    loader_in_test = DataLoader(dataset_in_test, batch_size=16, shuffle=True, collate_fn=collate_batch)


.. GENERATED FROM PYTHON SOURCE LINES 122-131

.. code-block:: Python
    :lineno-start: 122


    print("STAGE 1: Train Model")
    model = GRUClassifier(num_classes=20, n_vocab=len(vocab))
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # note: the scheduler is never stepped below, so the learning rate stays constant
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)
    criterion = OutlierExposureLoss()
    model.to(device)


.. GENERATED FROM PYTHON SOURCE LINES 132-178

.. code-block:: Python
    :lineno-start: 132


    for epoch in range(n_epochs):
        print(f"Epoch {epoch}")
        model.train()
        loss_ema = None
        correct = 0
        total = 0

        bar = tqdm(loader_train)
        for inputs, labels in bar:
            inputs = inputs.to(device)
            labels = labels.to(device)

            logits = model(inputs)
            loss = criterion(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # exponential moving average of the loss, for display only
            loss_ema = loss.item() if loss_ema is None else loss_ema * 0.99 + loss.item() * 0.01

            # accuracy is computed on in-distribution samples only
            known = is_known(labels)
            pred = logits.max(dim=1).indices
            correct += pred[known].eq(labels[known]).sum().item()
            total += known.sum().item()

            bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {correct / max(total, 1):.2%}")

        with torch.no_grad():
            model.eval()

            correct = 0
            total = 0

            for inputs, labels in loader_in_test:
                inputs = inputs.to(device)
                labels = labels.to(device)
                logits = model(inputs)
                pred = logits.max(dim=1).indices
                correct += pred.eq(labels).sum().item()
                total += pred.shape[0]

            print(f"Test Accuracy: {correct / total:.2%}")


.. GENERATED FROM PYTHON SOURCE LINES 179-193

.. code-block:: Python
    :lineno-start: 180


    ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

    datasets = {}
    for ood_dataset in ood_datasets:
        dataset_out_test = ood_dataset(
            root=root, transform=prep, target_transform=ToUnknown(), download=True
        )
        # each test set mixes the in-distribution test data with one OOD dataset
        test_loader = DataLoader(
            dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
        )
        datasets[ood_dataset.__name__] = test_loader


.. GENERATED FROM PYTHON SOURCE LINES 194-195

Fit detectors to training data (some require this, some do not)

.. GENERATED FROM PYTHON SOURCE LINES 195-214

.. code-block:: Python
    :lineno-start: 195


    print("STAGE 2: Creating OOD Detectors")

    detectors = {}
    detectors["Entropy"] = Entropy(model)
    detectors["ViM"] = ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias)
    detectors["Mahalanobis"] = Mahalanobis(model.features, eps=0.0)
    detectors["KLMatching"] = KLMatching(model)
    detectors["MaxSoftmax"] = MaxSoftmax(model)
    detectors["EnergyBased"] = EnergyBased(model)
    detectors["MaxLogit"] = MaxLogit(model)

    print(f"> Fitting {len(detectors)} detectors")

    for name, detector in detectors.items():
        print(f"--> Fitting {name}")
        detector.to(device)
        detector.fit(loader_train)


.. GENERATED FROM PYTHON SOURCE LINES 215-233

.. code-block:: Python
    :lineno-start: 215


    print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
    results = []

    with torch.no_grad():
        for detector_name, detector in detectors.items():
            print(f"> Evaluating {detector_name}")
            for dataset_name, loader in datasets.items():
                print(f"--> {dataset_name}")
                metrics = OODMetrics()
                for x, y in loader:
                    metrics.update(detector(x.to(device)), y.to(device))

                r = {"Detector": detector_name, "Dataset": dataset_name}
                r.update(metrics.compute())
                results.append(r)


.. GENERATED FROM PYTHON SOURCE LINES 234-238

.. code-block:: Python
    :lineno-start: 235


    df = pd.DataFrame(results)
    # average over the OOD datasets; numeric_only skips the "Dataset" string column
    mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
    print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))


.. _sphx_glr_download_auto_examples_text_newsgroups_oe.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: newsgroups_oe.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: newsgroups_oe.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: newsgroups_oe.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
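
The Outlier Exposure objective used to train the model in this example can be
sketched in plain Python. This is a minimal illustration and not the
``pytorch_ood`` implementation of ``OutlierExposureLoss``: it assumes that
outliers are marked by a negative label, that the outlier term is the
cross-entropy between the uniform distribution and the softmax output, and that
this term is weighted by an assumed ``alpha=0.5``. The names ``log_softmax`` and
``oe_loss_sketch`` are hypothetical helpers introduced only for this sketch.

.. code-block:: Python

    import math


    def log_softmax(logits):
        """Numerically stable log-softmax of one logit vector (list of floats)."""
        m = max(logits)
        log_z = m + math.log(sum(math.exp(v - m) for v in logits))
        return [v - log_z for v in logits]


    def oe_loss_sketch(batch_logits, labels, alpha=0.5):
        """Mean cross-entropy on known samples plus alpha times the mean
        cross-entropy to the uniform distribution on outliers.

        Assumption: outliers are marked by a negative label.
        """
        ce_known, ce_uniform = [], []
        for logits, label in zip(batch_logits, labels):
            log_probs = log_softmax(logits)
            if label >= 0:
                # standard cross-entropy for an in-distribution sample
                ce_known.append(-log_probs[label])
            else:
                # negative mean log-probability over all classes
                ce_uniform.append(-sum(log_probs) / len(log_probs))

        def mean(values):
            return sum(values) / len(values) if values else 0.0

        return mean(ce_known) + alpha * mean(ce_uniform)


    flat = [0.0] * 20             # uniform prediction over 20 classes
    peaked = [10.0] + [0.0] * 19  # confident prediction for class 0

    print(oe_loss_sketch([peaked], [0]))   # known sample, correct class: small loss
    print(oe_loss_sketch([flat], [-1]))    # outlier, uniform prediction: 0.5 * log(20)
    print(oe_loss_sketch([peaked], [-1]))  # outlier, confident prediction: larger penalty

Under these assumptions, the outlier term is smallest when the prediction on an
outlier is uniform, so confident predictions on outlier text are penalized while
in-distribution samples are still trained with ordinary cross-entropy.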