Source code for pytorch_ood.dataset.img.pixmix

import os
from os.path import join
from typing import Optional, Callable, Tuple, Any

from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity, extract_archive

import logging


log = logging.getLogger(__name__)


class PixMixExampleDatasets(VisionDataset):
    google_drive_id = "1qC2gIUx9ARU7zhgI4IwGD3YcFhm8J4cA"
    filename = "fractals_and_fvis.tar"
    subdirs = {
        "fractals": "fractals/images/",
        "features": "first_layers_resized256_onevis/images/",
    }
    base_folder = "fractals_and_fvis"
    md5sum = "3619fb7e2c76130749d97913fdd3ab27"

    def __init__(
        self,
        root: str,
        subset: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        super(PixMixExampleDatasets, self).__init__(
            root, transform=transform, target_transform=target_transform
        )

        if subset not in self.subdirs.keys():
            raise ValueError(f"Invalid subset '{subset}'. Allowed: {list(self.subdirs.keys())}")

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted." + " You can use download=True to download it"
            )

        self.basedir = join(self.root, self.base_folder, self.subdirs[subset])
        self.files = os.listdir(self.basedir)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        file, target = self.files[index], -1
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        path = os.path.join(self.basedir, file)
        img = Image.open(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target

    def __len__(self) -> int:
        return len(self.files)

    def _check_integrity(self) -> bool:
        fpath = os.path.join(self.root, self.filename)
        return check_integrity(fpath, self.md5sum)

    def download(self) -> None:
        if self._check_integrity():
            log.debug("Files already downloaded and verified")
            return

        try:
            import gdown

            archive = os.path.join(self.root, self.filename)
            if not gdown.download(id=self.google_drive_id, output=archive):
                raise Exception("File must be downloaded manually")

            log.info("Extracting {archive} to {self.root}")
            extract_archive(archive, self.root, remove_finished=False)

        except ImportError:
            raise ImportError("You have to install 'gdown' to use this dataset.")


[docs] class FeatureVisDataset(PixMixExampleDatasets): """ Dataset with Feature visualizations, as used in *PixMix: Dreamlike Pictures Comprehensively Improve Safety Measures*. :see Paper: `ArXiv <https://arxiv.org/abs/2112.05135>`__ """ def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download=False, ) -> None: super(FeatureVisDataset, self).__init__( root, subset="features", transform=transform, target_transform=target_transform, download=download, )
[docs] class FractalDataset(PixMixExampleDatasets): """ Dataset with Fractals, as used in *PixMix: Dreamlike Pictures Comprehensively Improve Safety Measures*. :see Paper: `ArXiv <https://arxiv.org/abs/2112.05135>`__ """ def __init__( self, root: str, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download=False, ) -> None: super(FractalDataset, self).__init__( root, subset="fractals", transform=transform, target_transform=target_transform, download=download, )