# Source code for pytorch_ood.dataset.txt.wiki

"""
"""
import logging
import os
from typing import Tuple

from torch.utils.data import Dataset
from torchvision.datasets.utils import download_and_extract_archive

log = logging.getLogger(__name__)


class WikiText2(Dataset):
    """
    WikiText-2 dataset.

    Contains a collection of over 2 million tokens extracted from the set of
    verified Good and Featured articles on Wikipedia.
    Usually used as OOD (training) data, for example, for
    :class:`Outlier Exposure <pytorch_ood.loss.OutlierExposureLoss>`.
    Labels are -1 by default.

    Split can be one of ``train``, ``test`` and ``val``.

    :see Paper: `ArXiv <https://arxiv.org/abs/1609.07843>`__
    """

    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip"
    md5 = "542ccefacc6c27f945fb54453812b3cd"
    # directory (relative to ``root``) that the archive extracts into
    base_dir = "wikitext-2"
    # maps the public split name to the token file inside ``base_dir``
    filenames = {
        "train": "wiki.train.tokens",
        "test": "wiki.test.tokens",
        "val": "wiki.valid.tokens",
    }

    def __init__(
        self, root: str, split: str, transform=None, target_transform=None, download: bool = True
    ):
        """
        :param root: directory in which the dataset is stored (``~`` is expanded)
        :param split: one of the keys of :attr:`filenames`
        :param transform: optional callable applied to each text line
        :param target_transform: optional callable applied to each label
        :param download: if true, download and extract the archive unless the
            split file can already be read
        :raises ValueError: if ``split`` is not a known split name
        """
        # Sequence membership (instead of dict membership) so that unhashable
        # values are still rejected with ValueError rather than TypeError.
        if split not in tuple(self.filenames):
            raise ValueError(f"Invalid split: {split}")

        super().__init__()

        self.root = os.path.expanduser(root)
        self.transforms = transform
        self.target_transform = target_transform
        self.split = split

        if download:
            self._download()

        self._data = self._load_data()

    def _download(self):
        """Download and extract the archive unless the split file is already readable."""
        if self._check_integrity():
            log.info("Files already downloaded and verified")
            return

        download_and_extract_archive(
            url=self.url, download_root=self.root, extract_root=self.root, md5=self.md5
        )

    def _load_data(self) -> list:
        """
        Read the token file of the selected split.

        :return: list with one string per line of the file; whitespace runs are
            collapsed to single spaces and empty lines are kept as ``""``
        """
        path = os.path.join(self.root, self.base_dir, self.filenames[self.split])

        with open(path, "r", encoding="utf8") as f:
            # split() + join normalises whitespace and drops the trailing newline
            return [" ".join(line.split()) for line in f]

    def _check_integrity(self) -> bool:
        """Return ``True`` if the split file can be read and decoded, ``False`` otherwise."""
        try:
            self._load_data()
        except (OSError, UnicodeError):
            # missing, unreadable or undecodable file: treat as "not downloaded"
            return False
        return True

    def __getitem__(self, index) -> Tuple:
        """
        :param index: position of the line in the split file
        :return: tuple ``(text, label)``, where the label is -1 before
            ``target_transform`` is applied
        """
        x = self._data[index]
        y = -1

        if self.target_transform is not None:
            y = self.target_transform(y)

        if self.transforms is not None:
            x = self.transforms(x)

        return x, y

    def __len__(self):
        """Number of lines in the loaded split."""
        return len(self._data)
class WikiText103(WikiText2):
    """
    WikiText-103 dataset.

    Contains a collection of over 100 million tokens extracted from the set of
    verified Good and Featured articles on Wikipedia.
    Usually used as OOD (training) data, for example, for
    :class:`Outlier Exposure <pytorch_ood.loss.OutlierExposureLoss>`.
    Labels are -1 by default.

    Split can be one of ``train``, ``test`` and ``val``.

    :see Paper: `ArXiv <https://arxiv.org/abs/1609.07843>`__
    """

    # Only the archive location, checksum and extraction directory differ from
    # the base class; downloading and loading are inherited from WikiText2.
    url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip"
    md5 = "9ddaacaf6af0710eda8c456decff7832"
    base_dir = "wikitext-103"
    filenames = {
        "train": "wiki.train.tokens",
        "test": "wiki.test.tokens",
        "val": "wiki.valid.tokens",
    }