"""
Much of the code is taken from the baseline-implementation:
https://github.com/hendrycks/error-detection/blob/master/NLP/Categorization/Reuters52.ipynb
"""
import logging
import os
import re
from typing import Tuple
import numpy as np
from torch.utils.data import ConcatDataset, Dataset
from torchvision.datasets.utils import download_url
from .stop_words import stop_words
# Module-level logger used by the dataset classes below (e.g. download status messages).
log = logging.getLogger(__name__)
[docs]
class Reuters52(Dataset):
    """
    Stemmed version of the Reuters 52 dataset, as used by Hendrycks et al.
    Contains 52 classes.

    Each item is a ``(text, target)`` pair: ``text`` is the whitespace-joined
    token string of a document with stop words removed, and ``target`` is the
    integer class index looked up in :attr:`class2index`.

    :param root: directory the split file is read from / downloaded to
    :param transform: optional callable applied to the text of each item
    :param target_transform: optional callable applied to the target of each item
    :param train: load the training split if ``True``, otherwise the test split
    :param download: download the split file if it cannot already be loaded
    """

    train_url = "https://raw.githubusercontent.com/hendrycks/error-detection/master/NLP/Categorization/data/r52-train.txt"
    test_url = "https://raw.githubusercontent.com/hendrycks/error-detection/master/NLP/Categorization/data/r52-test.txt"
    test_md5 = "8a82cdf79e111df1bb23a9bbc48f6d25"
    train_md5 = "6b1d32bd95e95c1c26cd592d3bdb8c0e"
    train_filename = "r52-train-stemmed.txt"
    test_filename = "r52-test-stemmed.txt"

    # Class name -> integer target. No class name is a prefix of another, so
    # prefix matching in ``_labels_to_targets`` is unambiguous.
    class2index = {
        "acq": 0,
        "alum": 1,
        "bop": 2,
        "carcass": 3,
        "cocoa": 4,
        "coffee": 5,
        "copper": 6,
        "cotton": 7,
        "cpi": 8,
        "cpu": 9,
        "crude": 10,
        "dlr": 11,
        "earn": 12,
        "fuel": 13,
        "gas": 14,
        "gnp": 15,
        "gold": 16,
        "grain": 17,
        "heat": 18,
        "housing": 19,
        "income": 20,
        "instal-debt": 21,
        "interest": 22,
        "ipi": 23,
        "iron-steel": 24,
        "jet": 25,
        "jobs": 26,
        "lead": 27,
        "lei": 28,
        "livestock": 29,
        "lumber": 30,
        "meal-feed": 31,
        "money-fx": 32,
        "money-supply": 33,
        "nat-gas": 34,
        "nickel": 35,
        "orange": 36,
        "pet-chem": 37,
        "platinum": 38,
        "potato": 39,
        "reserves": 40,
        "retail": 41,
        "rubber": 42,
        "ship": 43,
        "strategic-metal": 44,
        "sugar": 45,
        "tea": 46,
        "tin": 47,
        "trade": 48,
        "veg-oil": 49,
        "wpi": 50,
        "zinc": 51,
    }

    def __init__(self, root, transform=None, target_transform=None, train=True, download=True):
        super().__init__()
        self.root = os.path.expanduser(root)
        self.transforms = transform
        self.target_transform = target_transform
        self.is_train = train
        self._targets = []
        self._analyzer = None
        if download:
            self._download()
        self._labels, self._data = self._load_data()
        # mapping class names to integers
        self._targets = self._labels_to_targets(self._labels)

    def _labels_to_targets(self, labels) -> np.ndarray:
        """
        Map raw label strings to integer class indices.

        Each label is assigned the index of the first class name in
        :attr:`class2index` that it starts with.

        :param labels: iterable of raw label strings, parallel to ``self._data``
        :return: numpy array of class indices, one per label
        :raises ValueError: if a label matches no known class. Dropping such a
            label instead would shift every later target relative to its text.
        """
        targets = []
        for label in labels:
            for clazz, index in self.class2index.items():
                if label.startswith(clazz):
                    targets.append(index)
                    break
            else:
                raise ValueError(f"Label {label!r} does not match any known Reuters52 class")
        return np.array(targets)

    def _download(self):
        """Download the current split's file into ``root`` unless it can already be loaded."""
        if self._check_integrity():
            log.info("Files already downloaded and verified")
            return
        if self.is_train:
            url, filename, md5 = self.train_url, self.train_filename, self.train_md5
        else:
            url, filename, md5 = self.test_url, self.test_filename, self.test_md5
        download_url(url, self.root, filename, md5)

    def _load_data(self) -> Tuple:
        """
        Parse the current split's file.

        The first two whitespace-separated tokens of each line are joined with
        ``"."`` to form the raw label. The remaining tokens, minus stop words,
        form the text. Blank lines are skipped so labels and texts stay aligned.

        :return: ``(labels, texts)`` as two parallel lists of strings
        """
        filename = self.train_filename if self.is_train else self.test_filename
        path = os.path.join(self.root, filename)
        labels, texts = [], []
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                words = line.split()
                if not words:
                    continue
                texts.append(" ".join(word for word in words[2:] if word not in stop_words))
                labels.append(".".join(words[0:2]))
        return labels, texts

    def _check_integrity(self):
        """
        Return ``True`` if the current split's file can be opened and parsed.

        No checksum is compared here. The md5 is only passed on to
        ``download_url`` in :meth:`_download`.
        """
        try:
            self._load_data()
        except Exception:
            # Best effort: any failure simply means "(re-)download the file".
            log.debug("Could not load Reuters data from %s", self.root, exc_info=True)
            return False
        return True

    def __getitem__(self, index):
        """Return the ``(text, target)`` pair at ``index`` after applying any transforms."""
        x = self._data[index]
        y = self._targets[index]
        if self.target_transform:
            y = self.target_transform(y)
        if self.transforms:
            x = self.transforms(x)
        return x, y

    def __len__(self):
        """Return the number of documents in the loaded split."""
        return len(self._data)
[docs]
class Reuters8(Reuters52):
    """
    Stemmed version of the Reuters 8 dataset, as used by Hendrycks et al.
    Contains 8 classes.

    Unlike :class:`Reuters52`, the integer target of each sample is read
    directly from the data file: it is the first character of each line
    after non-word characters are collapsed. Because only one character is
    read, this relies on labels being single digits, which holds for 8 classes.

    NOTE(review): the inherited ``class2index`` (52 classes) is not used by this
    class. Confirm whether an 8-class mapping should be defined here.
    """

    train_url = "https://raw.githubusercontent.com/hendrycks/error-detection/master/NLP/Categorization/data/r8-train.txt"
    test_url = "https://raw.githubusercontent.com/hendrycks/error-detection/master/NLP/Categorization/data/r8-test.txt"
    test_md5 = "7a54d01272de570d13d0c58cf8aa3c8d"
    train_md5 = "c979a285f1c132c5be8d554b385f2c49"
    train_filename = "r8-train-stemmed.txt"
    test_filename = "r8-test-stemmed.txt"

    def __init__(self, root, transform=None, target_transform=None, train=True, download=True):
        # Deliberately skip Reuters52.__init__: its class-name -> index mapping
        # does not apply, because targets are stored as integers in the file.
        super(Reuters52, self).__init__()
        self.root = os.path.expanduser(root)
        self.transforms = transform
        self.target_transform = target_transform
        self.is_train = train
        self._analyzer = None
        if download:
            self._download()
        self._targets, self._data = self._load_data()

    def _load_data(self) -> Tuple:
        """
        Parse the current split's file.

        Each line has runs of non-word characters collapsed to single spaces.
        Its first character is taken as the integer target. The rest, with stop
        words removed, becomes the text. Lines that are empty after this
        cleanup are skipped, so they no longer raise ``IndexError`` on ``line[0]``.

        :return: ``(targets, texts)``, where ``targets`` is an int numpy array
            parallel to the list of text strings
        """
        filename = self.train_filename if self.is_train else self.test_filename
        path = os.path.join(self.root, filename)
        targets, texts = [], []
        with open(path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = re.sub(r"\W+", " ", raw_line).strip()
                if not line:
                    continue
                targets.append(line[0])
                texts.append(" ".join(word for word in line[1:].split() if word not in stop_words))
        return np.array(targets, dtype=int), texts