Source code for pytorch_ood.dataset.img.odin

"""
Datasets used for testing in ODIN

First used in:the `ODIN paper<https://github.com/facebookresearch/odin>`__.
"""

import logging
import os
from typing import Any, Callable, Optional, Tuple

from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity, download_and_extract_archive

log = logging.getLogger(__name__)


[docs] class TinyImageNetCrop(VisionDataset): """ Cropped version of the TinyImageNet, often used as OOD data. :see Paper: `ArXiv <https://arxiv.org/abs/1706.02690>`__ """ base_folder = "Imagenet/test/" url = "https://www.dropbox.com/s/raw/avgm2u562itwpkl/Imagenet.tar.gz" filename = "Imagenet.tar.gz" tgz_md5 = "7c0827e4246c3718a5ee076e999e52e5" def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, ) -> None: super(TinyImageNetCrop, self).__init__( root, transform=transform, target_transform=target_transform ) if download: self.download() if not self._check_integrity(): raise RuntimeError( "Dataset not found or corrupted." + " You can use download=True to download it" ) self.basedir = os.path.join(self.root, self.base_folder) self.files = os.listdir(self.basedir) def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ file, target = self.files[index], -1 # doing this so that it is consistent with all other datasets # to return a PIL Image path = os.path.join(self.root, self.base_folder, file) img = Image.open(path).convert("RGB") if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self) -> int: return len(self.files) def _check_integrity(self) -> bool: root = self.root fpath = os.path.join(root, self.filename) return check_integrity(fpath, self.tgz_md5) def download(self) -> None: if self._check_integrity(): log.debug("Files already downloaded and verified") return download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)
[docs] class TinyImageNetResize(TinyImageNetCrop): """ Resized version of the TinyImageNet, often used as OOD data. :see Paper: `ArXiv <https://arxiv.org/abs/1706.02690>`__ """ base_folder = "Imagenet_resize/Imagenet_resize/" url = "https://www.dropbox.com/s/raw/kp3my3412u5k9rl/Imagenet_resize.tar.gz" filename = "Imagenet_resize.tar.gz" tgz_md5 = "0f9ff11d45babf2eff5fe12281d1ac31" def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, ) -> None: super(TinyImageNetResize, self).__init__( root, transform=transform, target_transform=target_transform, download=download, )
[docs] class LSUNCrop(TinyImageNetCrop): """ Cropped version of the LSUN, often used as OOD data. :see Paper: `ArXiv <https://arxiv.org/abs/1706.02690>`__ """ base_folder = "LSUN/test/" url = "https://www.dropbox.com/s/raw/fhtsw1m3qxlwj6h/LSUN.tar.gz" filename = "LSUN.tar.gz" tgz_md5 = "458a0a0ab8e5f1cb4516d7400568e460" def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, ) -> None: super(LSUNCrop, self).__init__( root, transform=transform, target_transform=target_transform, download=download, )
[docs] class LSUNResize(TinyImageNetCrop): """ Resized version of the LSUN dataset, often used as OOD data. :see Paper: `ArXiv <https://arxiv.org/abs/1706.02690>`__ """ base_folder = "LSUN_resize/LSUN_resize" url = "https://www.dropbox.com/s/raw/moqh2wh8696c3yl/LSUN_resize.tar.gz" filename = "LSUN_resize.tar.gz" tgz_md5 = "278b7b31c8cb7e804a1465a8ce03a2dc" def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, ) -> None: super(LSUNResize, self).__init__( root, transform=transform, target_transform=target_transform, download=download, )