.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/text/newsgroups_oe.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_text_newsgroups_oe.py:


Newsgroups Outlier Exposure
==============================

Benchmark code for the 20 Newsgroups dataset, trained with Outlier Exposure on the WikiText2 dataset.
The detectors are evaluated against three different text datasets (Reuters52, Multi30k, and
WMT16Sentences), and the table reports the mean performance across them.
The example uses the GRU model from the OOD detection baseline paper.

The original results cannot be reproduced, as the dictionaries (word-to-token mappings) are not available.

All values are percentages.

+-------------+-------+---------+----------+----------+
| Detector    | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR |
+=============+=======+=========+==========+==========+
| ViM         | 57.87 | 67.14   | 65.42    | 53.18    |
+-------------+-------+---------+----------+----------+
| Mahalanobis | 63.27 | 68.39   | 69.40    | 50.52    |
+-------------+-------+---------+----------+----------+
| KLMatching  | 92.92 | 91.70   | 93.17    | 21.32    |
+-------------+-------+---------+----------+----------+
| MaxSoftmax  | 93.55 | 92.27   | 94.50    | 20.17    |
+-------------+-------+---------+----------+----------+
| Entropy     | 94.33 | 93.05   | 95.10    | 19.76    |
+-------------+-------+---------+----------+----------+
| EnergyBased | 94.63 | 93.14   | 95.67    | 16.90    |
+-------------+-------+---------+----------+----------+
| MaxLogit    | 94.66 | 93.14   | 95.66    | 17.18    |
+-------------+-------+---------+----------+----------+

.. GENERATED FROM PYTHON SOURCE LINES 33-62

.. code-block:: Python
    :lineno-start: 33

    import pandas as pd
    import torch
    from torch.utils.data import DataLoader
    from torchtext.data.utils import get_tokenizer
    from torchtext.vocab import build_vocab_from_iterator
    from tqdm import tqdm

    from pytorch_ood.dataset.txt import Multi30k, NewsGroup20, Reuters52, WikiText2, WMT16Sentences
    from pytorch_ood.detector import (
        ODIN,
        EnergyBased,
        Entropy,
        KLMatching,
        Mahalanobis,
        MaxLogit,
        MaxSoftmax,
        ViM,
    )
    from pytorch_ood.loss import OutlierExposureLoss
    from pytorch_ood.model import GRUClassifier
    from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed, is_known

    fix_random_seed(123)

    n_epochs = 5
    lr = 0.001
    device = "cuda:0"
    root = "data"

.. GENERATED FROM PYTHON SOURCE LINES 63-64

Download the in-distribution training data, then build the tokenizer and vocabulary from it.

.. GENERATED FROM PYTHON SOURCE LINES 64-82

.. code-block:: Python
    :lineno-start: 64

    train_dataset_in = NewsGroup20(root, train=True, download=True)

    tokenizer = get_tokenizer("basic_english")


    def yield_tokens(data_iter):
        for text, _ in data_iter:
            yield tokenizer(text)


    vocab = build_vocab_from_iterator(yield_tokens(train_dataset_in))
    vocab.set_default_index(0)  # out-of-vocabulary tokens map to index 0


    def prep(x):
        return torch.tensor([vocab[v] for v in tokenizer(x)], dtype=torch.int64)

.. GENERATED FROM PYTHON SOURCE LINES 83-89

.. code-block:: Python
    :lineno-start: 83

    train_dataset_in = NewsGroup20(root, train=True, transform=prep)
    dataset_in_test = NewsGroup20(root, train=False, transform=prep)
    train_ood_dataset = WikiText2(
        root, split="train", download=True, transform=prep, target_transform=ToUnknown()
    )

.. GENERATED FROM PYTHON SOURCE LINES 90-91

Left-pad every sequence in a batch with zeros to the length of the longest one, then create the data loaders.

.. GENERATED FROM PYTHON SOURCE LINES 91-114

.. code-block:: Python
    :lineno-start: 93

    def collate_batch(batch):
        texts = [i[0] for i in batch]
        labels = torch.tensor([i[1] for i in batch], dtype=torch.int64)
        t_lengths = torch.tensor([len(t) for t in texts])
        max_t_length = torch.max(t_lengths)

        padded = []
        for text in texts:
            t = torch.cat([torch.zeros(max_t_length - len(text), dtype=torch.long), text])
            padded.append(t)
        return torch.stack(padded, dim=0), labels


    loader_train = DataLoader(
        train_dataset_in + train_ood_dataset,
        batch_size=20,
        shuffle=True,
        collate_fn=collate_batch,
    )
    loader_in_test = DataLoader(dataset_in_test, batch_size=16, shuffle=True, collate_fn=collate_batch)

.. GENERATED FROM PYTHON SOURCE LINES 115-124

.. code-block:: Python
    :lineno-start: 115

    print("STAGE 1: Train Model")
    model = GRUClassifier(num_classes=20, n_vocab=len(vocab))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs)  # never stepped below, so lr stays constant
    criterion = OutlierExposureLoss()

    model.to(device)

.. GENERATED FROM PYTHON SOURCE LINES 125-171

.. code-block:: Python
    :lineno-start: 125

    for epoch in range(n_epochs):
        print(f"Epoch {epoch}")
        model.train()

        loss_ema = None
        correct = 0
        total = 0

        bar = tqdm(loader_train)
        for n, batch in enumerate(bar):
            inputs, labels = batch

            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_ema = loss.item() if not loss_ema else loss_ema * 0.99 + loss.item() * 0.01

            # accuracy is computed on in-distribution (known) samples only
            pred = logits.max(dim=1).indices
            correct += pred[is_known(labels)].eq(labels[is_known(labels)]).sum().item()
            total += is_known(labels).sum().item()

            bar.set_postfix_str(f"loss: {loss_ema:.2f} acc: {correct / total:.2%}")

        with torch.no_grad():
            model.eval()

            correct = 0
            total = 0

            for n, batch in enumerate(loader_in_test):
                inputs, labels = batch
                inputs = inputs.to(device)
                labels = labels.to(device)
                logits = model(inputs)
                pred = logits.max(dim=1).indices
                correct += pred.eq(labels).sum().item()
                total += pred.shape[0]

            print(f"Test Accuracy: {correct / total:.2%}")

.. GENERATED FROM PYTHON SOURCE LINES 172-186

.. code-block:: Python
    :lineno-start: 173

    ood_datasets = [Reuters52, Multi30k, WMT16Sentences]

    datasets = {}
    for ood_dataset in ood_datasets:
        dataset_out_test = ood_dataset(
            root=root, transform=prep, target_transform=ToUnknown(), download=True
        )
        test_loader = DataLoader(
            dataset_in_test + dataset_out_test, batch_size=16, collate_fn=collate_batch
        )
        datasets[ood_dataset.__name__] = test_loader

.. GENERATED FROM PYTHON SOURCE LINES 187-188

Fit the detectors to the training data (some detectors require this, others do not).

.. GENERATED FROM PYTHON SOURCE LINES 188-206

.. code-block:: Python
    :lineno-start: 188

    print("STAGE 2: Creating OOD Detectors")

    detectors = {}
    detectors["Entropy"] = Entropy(model)
    detectors["ViM"] = ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias)
    detectors["Mahalanobis"] = Mahalanobis(model.features, eps=0.0)
    detectors["KLMatching"] = KLMatching(model)
    detectors["MaxSoftmax"] = MaxSoftmax(model)
    detectors["EnergyBased"] = EnergyBased(model)
    detectors["MaxLogit"] = MaxLogit(model)

    print(f"> Fitting {len(detectors)} detectors")

    for name, detector in detectors.items():
        print(f"--> Fitting {name}")
        detector.fit(loader_train, device=device)

.. GENERATED FROM PYTHON SOURCE LINES 207-225

.. code-block:: Python
    :lineno-start: 207

    print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.")
    results = []

    with torch.no_grad():
        for detector_name, detector in detectors.items():
            print(f"> Evaluating {detector_name}")
            for dataset_name, loader in datasets.items():
                print(f"--> {dataset_name}")
                metrics = OODMetrics()
                for x, y in loader:
                    metrics.update(detector(x.to(device)), y.to(device))

                r = {"Detector": detector_name, "Dataset": dataset_name}
                r.update(metrics.compute())
                results.append(r)

.. GENERATED FROM PYTHON SOURCE LINES 226-230

.. code-block:: Python
    :lineno-start: 227

    df = pd.DataFrame(results)
    # numeric_only=True skips the string-valued "Dataset" column when averaging
    mean_scores = df.groupby("Detector").mean(numeric_only=True) * 100
    print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f"))

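
The AUROC and FPR95TPR columns in the results table can be illustrated with a small, dependency-free sketch.
This is not the implementation behind ``OODMetrics``. It assumes that OOD samples are the positive class, that a higher detector score means "more likely OOD", and that negative labels mark OOD samples (the convention that ``ToUnknown`` and ``is_known`` suggest in the example above).
The function names ``auroc`` and ``fpr_at_tpr`` and the toy data are hypothetical and exist only for illustration.

```python
def _split(scores, is_ood):
    """Separate scores into OOD (positive) and in-distribution (negative) lists."""
    pos = [s for s, o in zip(scores, is_ood) if o]
    neg = [s for s, o in zip(scores, is_ood) if not o]
    if not pos or not neg:
        raise ValueError("need at least one OOD and one in-distribution sample")
    return pos, neg


def auroc(scores, is_ood):
    """Fraction of (OOD, in-distribution) pairs ranked correctly; ties count as 0.5."""
    pos, neg = _split(scores, is_ood)
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def fpr_at_tpr(scores, is_ood, tpr=0.95):
    """Lowest false positive rate among thresholds that reach the requested TPR."""
    pos, neg = _split(scores, is_ood)
    best = 1.0
    for threshold in sorted(set(scores)):
        # a sample is flagged as OOD when its score is >= threshold
        tp_rate = sum(p >= threshold for p in pos) / len(pos)
        fp_rate = sum(n >= threshold for n in neg) / len(neg)
        if tp_rate >= tpr:
            best = min(best, fp_rate)
    return best


# Toy data (assumption): class labels >= 0 are in-distribution, -1 marks OOD.
labels = [0, 3, 7, -1, 12, -1, -1, -1]
scores = [0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]
is_ood = [y < 0 for y in labels]

print(f"AUROC:    {auroc(scores, is_ood):.2%}")       # 93.75%
print(f"FPR95TPR: {fpr_at_tpr(scores, is_ood):.2%}")  # 25.00%
```

Under these assumptions, a higher AUROC and a lower FPR95TPR both indicate better separation between in-distribution and OOD inputs, which is how the table ranks the detectors.
The pairwise loop is quadratic in the number of samples, so it is suitable for illustration only.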