Source code for pytorch_ood.dataset.img.textures

import logging
import os
from os.path import join
from typing import Any, Callable, Optional, Tuple

from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity, download_and_extract_archive

log = logging.getLogger(__name__)


[docs] class Textures(VisionDataset): """ Textures dataset from the paper *Describing Textures in the Wild*, also known as DTD. Often used as OOD data. .. image :: https://production-media.paperswithcode.com/datasets/DTD-0000002377-abe5e400_AubcN36.jpg :width: 600px :alt: Textured Dataset :align: center :see Paper: `ArXiv <https://arxiv.org/abs/1311.3618v2>`__ :see Website: `Link <https://www.robots.ox.ac.uk/~vgg/data/dtd/>`__ """ base_folder = "dtd/images/" url = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz" filename = "textures-r1_0_1.tar.gz" tgz_md5 = "fff73e5086ae6bdbea199a49dfb8a4c1" def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False, ) -> None: super(Textures, self).__init__( root, transform=transform, target_transform=target_transform ) if download: self.download() if not self._check_integrity(): raise RuntimeError( "Dataset not found or corrupted." + " You can use download=True to download it" ) self.basedir = join(self.root, self.base_folder) self.files = [] for d in os.listdir(self.basedir): self.files.extend( [join(d, f) for f in os.listdir(join(self.basedir, d)) if not f.startswith(".")] ) log.info(f"Found {len(self.files)} texture files.") def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ file, target = self.files[index], -1 # doing this so that it is consistent with all other datasets # to return a PIL Image path = join(self.root, self.base_folder, file) img = Image.open(path) if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self) -> int: return len(self.files) def _check_integrity(self) -> bool: root = self.root fpath = os.path.join(root, self.filename) return check_integrity(fpath, self.tgz_md5) def download(self) -> None: if self._check_integrity(): log.debug("Files already downloaded and verified") return download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)