Source code for pytorch_ood.dataset.img.cifar

import logging
from os.path import join
from typing import Any, Callable, Optional, Tuple

import numpy as np
from PIL import Image

from .base import ImageDatasetBase

log = logging.getLogger(__name__)


class CIFAR10C(ImageDatasetBase):
    """
    Corrupted version of the CIFAR10 from the paper *Benchmarking Neural Network Robustness to Common
    Corruptions and Perturbations.*

    Images are read from one ``<subset>.npy`` file per corruption type and the labels from a single
    ``labels.npy`` file, all located in ``<root>/<base_folder>``. Everything is loaded into memory
    when the dataset is constructed.

    :see Website: `Zenodo <https://zenodo.org/record/2535967>`__
    :see Paper: `ArXiv <https://arxiv.org/abs/1903.12261>`__
    """

    # names of the available corruption types; each maps to a file ``<name>.npy``
    subsets = [
        "brightness",
        "contrast",
        "defocus_blur",
        "elastic_transform",
        "fog",
        "frost",
        "gaussian_blur",
        "gaussian_noise",
        "glass_blur",
        "impulse_noise",
        "jpeg_compression",
        "motion_blur",
        "pixelate",
        "saturate",
        "shot_noise",
        "snow",
        "spatter",
        "speckle_noise",
        "zoom_blur",
    ]

    base_folder = "CIFAR-10-C"
    url = "https://zenodo.org/record/2535967/files/CIFAR-10-C.tar"
    filename = "CIFAR-10-C.tar"
    md5hash = "56bf5dcef84df0e2308c6dcbcbbd8499"

    def __init__(
        self,
        root: str,
        subset: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        """
        :param root: directory in which the ``base_folder`` with the ``.npy`` files is looked up
        :param subset: one of the names in :attr:`subsets`, or ``"all"`` to concatenate every subset
        :param transform: optional callable applied to the PIL image in ``__getitem__``
        :param target_transform: optional callable applied to the label in ``__getitem__``
        :param download: passed on unchanged to the base class constructor
        :raises ValueError: if ``subset`` is neither ``"all"`` nor a known subset name
        """
        # Reject a bad subset name before the base class gets a chance to act
        # on ``download``, so a typo fails immediately.
        if subset != "all" and subset not in self.subsets:
            raise ValueError(f"Unknown Subset: {subset}")

        super().__init__(root, transform, target_transform, download)
        self.subset = subset

        # NOTE(review): paths are built from the ``root`` argument exactly as
        # given; if the base class normalises root (e.g. expands "~"), confirm
        # that both agree.
        folder = join(root, self.base_folder)

        if subset == "all":
            self.data = np.concatenate(
                [np.load(join(folder, f"{s}.npy")) for s in self.subsets]
            )
            # The label file is the same for every subset: read it once and
            # repeat it in memory instead of re-reading it per subset.
            labels = np.load(join(folder, "labels.npy"))
            self.targets = np.concatenate([labels] * len(self.subsets))
        else:
            self.data = np.load(join(folder, f"{subset}.npy"))
            self.targets = np.load(join(folder, "labels.npy"))

    def __len__(self) -> int:
        """
        :return: number of loaded images
        """
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        :param index: position of the sample in the loaded arrays
        :return: tuple ``(image, target)`` after the optional transforms have been applied
        """
        img = self.data[index]
        target = self.targets[index]

        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target
class CIFAR100C(CIFAR10C):
    """
    Corrupted version of the CIFAR100 from the paper *Benchmarking Neural Network Robustness to Common
    Corruptions and Perturbations.*

    Inherits all behaviour from :class:`CIFAR10C` and only overrides the class attributes that
    describe where the data comes from.

    :see Website: `Zenodo <https://zenodo.org/record/3555552>`__
    :see Paper: `ArXiv <https://arxiv.org/abs/1903.12261>`__
    """

    # archive location, file name and MD5 digest
    # (presumably consumed by the base class for download/verification -- confirm there)
    url = "https://zenodo.org/record/3555552/files/CIFAR-100-C.tar"
    filename = "CIFAR-100-C.tar"
    md5hash = "11f0ed0f1191edbf9fa23466ae6021d3"

    # folder below ``root`` from which the ``.npy`` files are loaded
    base_folder = "CIFAR-100-C/"