Source code for pytorch_ood.dataset.img.tinyimagenet

import logging
import os
from os.path import exists, join

from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import download_and_extract_archive

log = logging.getLogger(__name__)


class TinyImageNet(VisionDataset):
    """
    Small Version of the ImageNet with images of size :math:`64 \\times 64` from 200 classes used by
    Stanford. Each class has 500 images for training.

    .. image:: https://production-media.paperswithcode.com/datasets/Tiny_ImageNet-0000001404-a53923c3_XCrVSGm.jpg
        :width: 400px
        :alt: Tiny ImageNet Dataset
        :align: center

    This dataset is often used for training, but not included in Torchvision.

    Samples of the ``test`` subset carry no annotation; their target is ``-1``.

    :see Website: `Stanford <http://cs231n.stanford.edu/>`__
    """

    url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    dir_name = "tiny-imagenet-200"  # directory the archive extracts to, relative to ``root``
    tgz_md5 = "90528d7ca1a48142e341f4ef8d21d0de"  # checksum handed to the download helper
    filename = "tiny-imagenet-200.zip"
    subsets = ["train", "val", "test"]

    def __init__(
        self,
        root,
        subset="train",
        download=False,
        transform=None,
        target_transform=None,
    ):
        """
        :param root: directory that contains (or will receive) the ``tiny-imagenet-200`` folder
        :param subset: can be one of ``train``, ``val`` and ``test``
        :param download: if ``True``, download and extract the archive when it is not present
        :param transform: optional callable applied to each loaded image
        :param target_transform: optional callable applied to each target
        :raises ValueError: if ``subset`` is not one of the supported values
        :raises RuntimeError: if the dataset directory does not exist after the optional download
        """
        if subset not in self.subsets:
            raise ValueError(f"Invalid subset: {subset}. Possible values are {self.subsets}")

        super().__init__(root, target_transform=target_transform, transform=transform)

        self.subset = subset

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted." + " You can use download=True to download it"
            )

        # The class index is always derived from the training directory so that
        # all subsets share one mapping. Only directories count as classes;
        # stray files (e.g. OS metadata) would otherwise shift the indices.
        train_dir = join(self.root, self.dir_name, "train")
        classes = sorted(
            entry for entry in os.listdir(train_dir) if os.path.isdir(join(train_dir, entry))
        )
        #: maps class directory names to integer labels
        self.class_map = {c: n for n, c in enumerate(classes)}
        self.basename = join(self.root, self.dir_name, self.subset)

        self.paths = []
        self.labels = []

        if subset == "train":
            for d in classes:
                p = join(self.basename, d, "images")
                # os.listdir order is filesystem dependent; sort so that the
                # index -> sample mapping is reproducible across machines.
                files = [join(p, img) for img in sorted(os.listdir(p))]
                self.paths += files
                self.labels += [self.class_map[d]] * len(files)

        elif subset == "val":
            anno_file = join(self.basename, "val_annotations.txt")
            with open(anno_file, "r", encoding="utf-8") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        # tolerate blank / trailing lines instead of crashing
                        continue
                    # Only the first two whitespace-separated columns (file
                    # name and class name) are used; the rest is ignored.
                    path, label = fields[:2]
                    self.paths.append(join(self.basename, "images", path))
                    self.labels.append(self.class_map[label])

        elif subset == "test":
            d = join(self.basename, "images")
            self.paths = [join(d, img) for img in sorted(os.listdir(d))]
            # no annotations are read for the test subset
            self.labels = [-1] * len(self.paths)

    def download(self):
        """
        Download and extract the archive into ``root`` unless the dataset directory already exists.
        """
        if self._check_integrity():
            log.debug("Files already downloaded and verified")
            return

        download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

    def _check_integrity(self):
        """
        :return: ``True`` if the extracted dataset directory exists below ``root``.
            Only existence is checked, the content is not verified.
        """
        return exists(join(self.root, self.dir_name))

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        path, target = self.paths[index], self.labels[index]

        # Image.open is lazy and keeps the file open until the pixel data is
        # read. Load eagerly inside the context manager so the handle is
        # released even when no transform touches the image.
        with Image.open(path) as img:
            img.load()

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """
        :return: number of samples in the selected subset
        """
        return len(self.paths)