# Source code for pytorch_ood.dataset.img.mnistc

import logging
import os
from os.path import join
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image

from .base import ImageDatasetBase

log = logging.getLogger(__name__)


class MNISTC(ImageDatasetBase):
    """
    MNIST-C is MNIST with corruptions for benchmarking OOD methods.

    Split can be one of ``train``, ``test`` and ``leftovers``.

    Subsets can be one of ``all``, ``brightness``, ``canny_edges``, ``dotted_line``,
    ``fog``, ``glass_blur``, ``identity``, ``impulse_noise``, ``motion_blur``,
    ``rotate``, ``scale``, ``shear``, ``shot_noise``, ``spatter``, ``stripe``,
    ``translate`` and ``zigzag``. With ``all``, the arrays of every subset are
    concatenated in the order given by :attr:`subsets`.

    :see Paper: `ArXiv <https://arxiv.org/pdf/1906.02337.pdf>`__
    :see Download: `Zenodo <https://zenodo.org/record/3239543>`__

    .. image:: https://media.arxiv-vanity.com/render-output/4755208/corruption_examples.png
        :width: 800px
        :alt: MNIST-C Dataset examples
        :align: center
    """

    splits = ["train", "test", "leftovers"]

    subsets = [
        "brightness",
        "canny_edges",
        "dotted_line",
        "fog",
        "glass_blur",
        "identity",
        "impulse_noise",
        "motion_blur",
        "rotate",
        "scale",
        "shear",
        "shot_noise",
        "spatter",
        "stripe",
        "translate",
        "zigzag",
    ]

    # The four lists below are parallel: index 0 is used for the ``train`` and
    # ``test`` splits, index 1 for the ``leftovers`` split.
    base_folders = ["mnist_c", "mnist_c_leftovers"]

    urls = [
        "https://zenodo.org/record/3239543/files/mnist_c.zip",
        "https://zenodo.org/record/3239543/files/mnist_c_leftovers.zip",
    ]

    filenames = [
        "mnist_c.zip",
        "mnist_c_leftovers.zip",
    ]

    # Presumably MD5 checksums of the archives above, consumed by the inherited
    # download / integrity helpers (not visible in this file) -- TODO confirm.
    tgz_md5s = [
        "4b34b33045869ee6d424616cd3a65da3",
        "c365e9c25addd5c24454b19ac7101070",
    ]

    def __init__(
        self,
        root: str,
        subset: str,
        split: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """
        :param root: directory that contains (or will contain) the dataset folder
        :param subset: name of the corruption to load, or ``all``
        :param split: one of ``train``, ``test`` or ``leftovers``
        :param transform: callable applied to the PIL image in ``__getitem__``
        :param target_transform: callable applied to the label in ``__getitem__``
        :param download: if ``True``, call ``self.download()`` before the
            integrity check
        :raises ValueError: if ``subset`` or ``split`` is not a known value
        :raises RuntimeError: if ``self._check_integrity()`` reports failure
        """
        # NOTE: ``super(ImageDatasetBase, self)`` starts the MRO lookup *after*
        # ImageDatasetBase, so ImageDatasetBase.__init__ itself is skipped.
        # Presumably intentional, because the per-split attributes it would need
        # are only assigned below -- verify against ImageDatasetBase.
        super(ImageDatasetBase, self).__init__(
            root, transform=transform, target_transform=target_transform
        )

        if subset != "all" and subset not in self.subsets:
            raise ValueError(
                f"Unknown subset {subset!r}; expected 'all' or one of {self.subsets}"
            )

        if split not in self.splits:
            raise ValueError(f"Unknown split {split!r}; expected one of {self.splits}")

        # ``split`` is validated above, so anything that is not "leftovers" is
        # "train" or "test" and lives in the first archive.
        archive = 1 if split == "leftovers" else 0

        self.base_folder = join(root, self.base_folders[archive])
        self.url = self.urls[archive]
        self.filename = self.filenames[archive]
        self.tgz_md5 = self.tgz_md5s[archive]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        self.subset = subset

        # TODO: the "leftovers" split has no dedicated handling yet.
        # NOTE(review): with split="leftovers" the loader below looks for
        # ``leftovers_images.npy`` / ``leftovers_labels.npy`` inside the folders
        # named in ``subsets`` -- confirm that the leftovers archive actually
        # follows this layout.

        if subset == "all":
            self.data = np.concatenate(
                [self._load_array(name, f"{split}_images.npy") for name in self.subsets]
            )
            self.targets = np.concatenate(
                [self._load_array(name, f"{split}_labels.npy") for name in self.subsets]
            )
        else:
            self.data = self._load_array(subset, f"{split}_images.npy")
            self.targets = self._load_array(subset, f"{split}_labels.npy")

    def _load_array(self, subset: str, filename: str) -> np.ndarray:
        """Load ``<base_folder>/<subset>/<filename>`` with :func:`numpy.load`."""
        return np.load(join(self.base_folder, subset, filename))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img = self.data[index]
        target = self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        # ``squeeze()`` drops singleton axes before the array is interpreted as
        # an 8-bit grayscale ("L") image -- assumes the stored images carry a
        # trailing channel axis of size 1; TODO confirm.
        img = Image.fromarray(img.squeeze(), "L")

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        """Return the number of loaded images."""
        return len(self.data)