# Source code for pytorch_ood.dataset.txt.wmt16

"""
Much of the code is taken from the baseline-implementation:
https://github.com/hendrycks/outlier-exposure/blob/master/NLP_classification/wmt16/
"""
import logging
import os
from typing import List, Tuple

from torch.utils.data import Dataset
from torchvision.datasets.utils import check_integrity, download_url

log = logging.getLogger(__name__)


class WMT16Sentences(Dataset):
    """
    WMT16 sentences, as used by Hendrycks et al.
    Usually used as OOD data, labels are -1 by default.

    Each item is a ``(text, target)`` pair. ``text`` is one line of the
    sentence file with runs of whitespace collapsed to single spaces.
    ``target`` is ``-1``, optionally passed through ``target_transform``.
    """

    url = "https://raw.githubusercontent.com/hendrycks/outlier-exposure/master/NLP_classification/wmt16/wmt16_sentences"
    md5 = "6dff65f45ac112c150b8a2cc30509b03"
    filename = "wmt16_sentences"

    def __init__(self, root, transform=None, target_transform=None, download=True):
        """
        :param root: directory holding the sentence file (``~`` is expanded)
        :param transform: optional callable applied to each text
        :param target_transform: optional callable applied to each target
        :param download: if ``True``, download the file unless a copy with a
            matching MD5 checksum is already present in ``root``
        :raises FileNotFoundError: if the file is absent and ``download`` is ``False``
        """
        # Zero-argument super() so that ``Dataset.__init__`` itself is invoked.
        super().__init__()
        self.root = os.path.expanduser(root)
        # Attribute name kept as ``transforms`` for compatibility with existing users.
        self.transforms = transform
        self.target_transform = target_transform

        if download:
            self._download()

        self._data = self._load_data()

    def _download(self) -> None:
        """Fetch the sentence file into ``root`` unless a verified copy already exists."""
        if self._check_integrity():
            log.info("Files already downloaded and verified")
            return

        download_url(self.url, self.root, self.filename, self.md5)

    def _load_data(self) -> List[str]:
        """
        Read the sentence file, one sentence per line.

        :return: list of sentences in file order, whitespace-normalized
        :raises FileNotFoundError: if the file does not exist in ``root``
        """
        path = os.path.join(self.root, self.filename)
        # Explicit encoding so decoding does not depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            # ``str.split()`` with no argument splits on any whitespace run and
            # discards the trailing newline; re-joining normalizes spacing.
            return [" ".join(line.split()) for line in f]

    def _check_integrity(self) -> bool:
        """
        Return ``True`` if the sentence file exists in ``root`` and its MD5
        digest matches :attr:`md5`. A missing, truncated or otherwise
        corrupted file yields ``False``, which triggers a re-download.
        """
        return check_integrity(os.path.join(self.root, self.filename), self.md5)

    def __getitem__(self, index):
        """
        :param index: position of the sentence
        :return: ``(text, target)``, where ``target`` is ``-1`` unless changed
            by ``target_transform``
        """
        x = self._data[index]
        y = -1

        if self.target_transform is not None:
            y = self.target_transform(y)

        if self.transforms is not None:
            x = self.transforms(x)

        return x, y

    def __len__(self):
        """Number of sentences loaded from the file."""
        return len(self._data)