# Source code for pytorch_ood.dataset.img.imagenet

import logging
import os
from os.path import join
from typing import Callable, Optional

from PIL import Image
from torchvision.datasets import DatasetFolder
from torchvision.datasets.utils import check_integrity, download_and_extract_archive

from .base import ImageDatasetBase

# Logger named after this module; ImageNetA.download() reports an already-verified archive on it.
log: logging.Logger = logging.getLogger(name=__name__)


class ImageNetA(DatasetFolder):
    """
    From the paper *Natural Adversarial Examples*. Contains images that classifiers should be able
    to classify.

    The archive ``filename`` lives directly below ``root`` and its MD5 is checked against
    ``tgz_md5`` on every construction. Images are then read from ``root/base_folder`` with one
    sub-directory per class (standard :class:`DatasetFolder` layout). Only files ending in
    ``.jpg`` or ``.JPEG`` are used, and every image is returned as an RGB :class:`PIL.Image.Image`.

    :see Website: `GitHub <https://github.com/hendrycks/natural-adv-examples>`__
    :see Paper: `ArXiv <https://arxiv.org/abs/1907.07174>`__
    """

    base_folder = "imagenet-a"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar"
    filename = "imagenet-a.tar"
    tgz_md5 = "c3e55429088dc681f30d81f4726b6595"

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """
        :param root: directory that holds (or will receive) the archive and the extracted images
        :param transform: transform applied to each loaded image
        :param target_transform: transform applied to each class index
        :param download: if ``True``, download and extract the archive unless a verified copy
            already exists below ``root``
        :raises RuntimeError: if the archive is missing or fails the MD5 check
        """
        # check_integrity() does not expand "~" (download_and_extract_archive does), so expand
        # here; otherwise a "~"-relative root fails the check right after a successful download.
        self.root = os.path.expanduser(root)
        # DatasetFolder.__init__ below rebinds self.root to the image folder, so keep the
        # archive directory separately for later _check_integrity()/download() calls.
        self._archive_root = self.root

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted." + " You can use download=True to download it"
            )

        super().__init__(
            root=join(self._archive_root, self.base_folder),
            loader=self._load_rgb,
            is_valid_file=self._is_jpeg,
            transform=transform,
            target_transform=target_transform,
        )

    @staticmethod
    def _load_rgb(path: str) -> Image.Image:
        """Load the image at ``path`` fully into memory as an RGB image."""
        # Image.open alone is lazy and keeps the file open; reading inside a context manager
        # releases the handle at once. convert("RGB") also normalizes grayscale/CMYK/palette
        # JPEGs to 3 channels so all samples share one mode.
        with open(path, "rb") as fh:
            img = Image.open(fh)
            return img.convert("RGB")

    @staticmethod
    def _is_jpeg(path: str) -> bool:
        """Return ``True`` for files with a ``.jpg`` or ``.JPEG`` extension (case as written)."""
        return path.endswith((".jpg", ".JPEG"))

    def _check_integrity(self) -> bool:
        """Return ``True`` if the archive exists below the archive root and matches ``tgz_md5``."""
        return check_integrity(join(self._archive_root, self.filename), self.tgz_md5)

    def download(self) -> None:
        """Download and extract the archive into the archive root unless a verified copy exists."""
        if self._check_integrity():
            log.debug("Files already downloaded and verified")
            return
        download_and_extract_archive(
            self.url, self._archive_root, filename=self.filename, md5=self.tgz_md5
        )
class ImageNetO(ImageNetA):
    """
    ImageNet-O, from the paper *Natural Adversarial Examples*: a set of anomalies
    belonging to unforeseen classes.

    All loading, integrity checking and downloading is inherited from :class:`ImageNetA`;
    this class only supplies the archive name, its URL and MD5, and the extracted folder.

    :see Website: `GitHub <https://github.com/hendrycks/natural-adv-examples>`__
    :see Paper: `ArXiv <https://arxiv.org/abs/1907.07174>`__
    """

    filename = "imagenet-o.tar"
    tgz_md5 = "86bd7a50c1c4074fb18fc5f219d6d50b"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-o.tar"
    base_folder = "imagenet-o"
class ImageNetR(ImageNetA):
    """
    ImageNet-R(endition), from the paper *The Many Faces of Robustness: A Critical Analysis of
    Out-of-Distribution Generalization*.

    It holds renditions of ImageNet classes: art, cartoons, deviantart, graffiti, embroidery,
    graphics, origami, paintings, patterns, plastic objects, plush objects, sculptures,
    sketches, tattoos, toys, and video game renditions.

    Loading, integrity checking and downloading come from :class:`ImageNetA`; only the archive
    metadata below differs.

    :see Website: `GitHub <https://github.com/hendrycks/imagenet-r>`__
    :see Paper: `ArXiv <https://arxiv.org/abs/2006.16241>`__
    """

    filename = "imagenet-r.tar"
    tgz_md5 = "a61312130a589d0ca1a8fca1f2bd3337"
    url = "https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar"
    base_folder = "imagenet-r"
class ImageNetC(ImageDatasetBase):
    """
    Corrupted version of ImageNet from the paper *Benchmarking Neural Network Robustness to
    Common Corruptions and Perturbations.*

    The data is distributed as several subsets; one instance of this class uses exactly one:

    * ``noise`` (21GB): gaussian_noise, shot_noise, and impulse_noise.
    * ``blur`` (7GB): defocus_blur, glass_blur, motion_blur, and zoom_blur.
    * ``weather`` (12GB): frost, snow, fog, and brightness.
    * ``digital`` (7GB): contrast, elastic_transform, pixelate, and jpeg_compression.
    * ``extra`` (15GB): speckle_noise, spatter, gaussian_blur, and saturate.

    :see Paper: `ArXiv <https://arxiv.org/abs/1903.12261v1>`__
    """

    # Parallel per-subset tables: position i of every *_list below belongs to subset_list[i].
    subset_list = ["blur", "digital", "extra", "noise", "weather"]
    base_folder_list = [
        "ImageNetC/blur/",
        "ImageNetC/digital/",
        "ImageNetC/extra/",
        "ImageNetC/noise/",
        "ImageNetC/weather/",
    ]
    url_list = [
        "https://zenodo.org/record/2235448/files/blur.tar",
        "https://zenodo.org/record/2235448/files/digital.tar",
        "https://zenodo.org/record/2235448/files/extra.tar",
        "https://zenodo.org/record/2235448/files/noise.tar",
        "https://zenodo.org/record/2235448/files/weather.tar",
    ]
    filename_list = ["blur.tar", "digital.tar", "extra.tar", "noise.tar", "weather.tar"]
    tgz_md5_list = [
        "2d8e81fdd8e07fef67b9334fa635e45c",
        "89157860d7b10d5797849337ca2e5c03",
        "d492dfba5fc162d8ec2c3cd8ee672984",
        "e80562d7f6c3f8834afb1ecf27252745",
        "33ffea4db4d93fe4a428c40a6ce0c25d",
    ]

    def __init__(
        self,
        root: str,
        subset: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """
        :param root: root directory of the dataset
        :param subset: name of the subset to use; must be one of ``subset_list``
        :param transform: transform applied to the images
        :param target_transform: transform applied to the targets
        :param download: if ``True``, call ``self.download()`` before the integrity check
        :raises ValueError: if ``subset`` is not one of ``subset_list``
        :raises RuntimeError: if the integrity check fails
        """
        if subset not in self.subset_list:
            raise ValueError(f"Invalid subset: {subset}")

        # One lookup selects a consistent row from all parallel tables.
        row = self.subset_list.index(subset)
        self.base_folder, self.url, self.filename, self.tgz_md5 = (
            self.base_folder_list[row],
            self.url_list[row],
            self.filename_list[row],
            self.tgz_md5_list[row],
        )

        # super(ImageDatasetBase, self) starts the MRO search *after* ImageDatasetBase, so
        # ImageDatasetBase.__init__ itself is not run here.
        # NOTE(review): presumably intentional, with the download/integrity/listing steps
        # below standing in for it — confirm against ImageDatasetBase.
        super(ImageDatasetBase, self).__init__(
            root, transform=transform, target_transform=target_transform
        )

        if download:
            self.download()
        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted." + " You can use download=True to download it"
            )

        self.basedir = os.path.join(self.root, self.base_folder)
        # NOTE(review): os.listdir yields only the top-level entries of basedir, in arbitrary
        # (filesystem-dependent) order — confirm that is what consumers of ``files`` expect.
        self.files = os.listdir(self.basedir)