# Source code for pytorch_ood.dataset.img.smiyc

import logging
import os
from os.path import join
from typing import Any, Callable, List, Optional, Tuple

from PIL import Image
from torchvision.transforms.functional import to_tensor

from .base import ImageDatasetBase

log = logging.getLogger(__name__)


class SegmentMeIfYouCan(ImageDatasetBase):
    """
    Benchmark Dataset for Anomaly Segmentation.
    From the paper *SegmentMeIfYouCan: A Benchmark for Anomaly Segmentation*.

    Contains two subsets: RoadAnomaly21 and RoadObstacle21

    .. note:: Similar to Paper *Segment Every Out-of-Distribution Object*
        (`ArXiv <https://arxiv.org/pdf/2311.16516v3>`__,
        `Github <https://github.com/WenjieZhao1/S2M>`__)
        for ``RoadAnomaly21`` only **10** and for ``RoadObstacle21`` only **30** images are available.

    :see Paper: `ArXiv <https://arxiv.org/pdf/2104.14812>`__
    :see Website: `Website <https://segmentmeifyoucan.com/datasets>`__
    """

    # Directory (below the user-supplied root) that holds all subsets.
    root_dir_name = "SMIYC"

    # Names accepted for the ``subset`` constructor argument.
    subset_list = ["RoadAnomaly21", "RoadObstacle21"]

    # Folder of each subset, relative to ``<root>/SMIYC``.
    base_folders = {
        "RoadAnomaly21": "dataset_AnomalyTrack",
        "RoadObstacle21": "dataset_ObstacleTrack",
    }

    # Number of samples each subset is expected to contain.
    dataset_length = {"RoadAnomaly21": 10, "RoadObstacle21": 30}

    # Download location of each subset archive.
    url_list = {
        "RoadAnomaly21": "https://zenodo.org/record/5270237/files/dataset_AnomalyTrack.zip",
        "RoadObstacle21": "https://zenodo.org/record/5281633/files/dataset_ObstacleTrack.zip",
    }

    # (archive file name, md5 checksum) of each subset archive.
    filename_list = {
        "RoadAnomaly21": (
            "dataset_AnomalyTrack.zip",
            "231bf79ed58924bcd33d9cbe22e61076",
        ),
        "RoadObstacle21": (
            "dataset_ObstacleTrack.zip",
            "895fb36d18765482cc291f69e63d6da6",
        ),
    }

    VOID_LABEL = 1  #: void label, should be ignored during score calculation

    def __init__(
        self,
        root: str,
        subset: str,
        transform: Optional[Callable[[Tuple], Tuple]] = None,
        download: bool = False,
    ) -> None:
        """
        :param root: root path for dataset
        :param subset: one of ``RoadAnomaly21``, ``RoadObstacle21``
        :param transform: transformations to apply to images and masks, will get tuple as argument
        :param download: if dataset should be downloaded automatically
        :raises ValueError: if ``subset`` is not a known subset name
        :raises RuntimeError: if the integrity check fails or the number of
            samples found differs from the expected one
        """
        # Validate first: the per-subset lookups below would otherwise fail
        # with a bare KeyError before this check is ever reached.
        if subset not in self.subset_list:
            raise ValueError(f"Invalid subset: {subset}")

        root = join(root, self.root_dir_name)
        # Starts the MRO lookup *after* ImageDatasetBase, i.e. the
        # ImageDatasetBase constructor itself is skipped.
        # NOTE(review): presumably because url/filename must be set before
        # download() / _check_integrity() run -- confirm against base class.
        super(ImageDatasetBase, self).__init__(root, transform=transform)

        self.url = self.url_list[subset]
        self.filename, self.tgz_md5 = self.filename_list[subset]

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        self.data_dir = join(root, self.base_folders[subset])
        self.all_images, self.all_masks = self._get_file_list(self.data_dir, subset)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        expected = self.dataset_length[subset]
        if len(self.all_images) != expected:
            raise RuntimeError(
                f"Expected {expected} images for subset {subset}, "
                f"found {len(self.all_images)} in {self.data_dir}"
            )

    def __len__(self) -> int:
        """
        :returns: number of image/mask pairs in the selected subset
        """
        return len(self.all_images)

    def _get_file_list(self, root, subset) -> Tuple[List[str], List[str]]:
        """
        Collect image and mask paths of a subset.

        The ``labels_masks`` folder below ``root`` is listed (non-recursively).
        Every ``.png`` file whose name does not contain ``color`` is treated as
        a mask, and the path of the matching image is derived from its name.

        :param root: directory of the subset, containing ``images`` and ``labels_masks``
        :param subset: one of ``RoadAnomaly21``, ``RoadObstacle21``
        :returns: tuple ``(image_paths, mask_paths)`` of equally long lists
        """
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> sample mapping is reproducible across platforms.
        mask_names = sorted(os.listdir(join(root, "labels_masks")))

        all_images = []
        all_masks = []

        for name in mask_names:
            if not name.endswith(".png") or "color" in name:
                continue

            parts = name.split("_")
            if subset == "RoadAnomaly21":
                # image name is the part before the first underscore
                image_name = parts[0] + ".jpg"
            elif subset == "RoadObstacle21":
                # image name consists of the first two underscore-separated parts
                image_name = parts[0] + "_" + parts[1] + ".webp"
            else:
                # unknown subset: nothing is collected, reported below
                continue

            all_images.append(join(root, "images", image_name))
            all_masks.append(join(root, "labels_masks", name))

        if len(all_images) == 0:
            log.error("No images found in the directory")

        if len(all_masks) == 0:
            log.error("No masks found in the directory")

        if len(all_images) != len(all_masks):
            raise Exception(
                f"Number of images and masks do not match: num_img:{len(all_images)}, num_masks:{len(all_masks)}"
            )

        return all_images, all_masks

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        :param index: index
        :returns: (image, target) where target is the annotation of the image.
            After remapping, target contains ``0`` for unchanged zero pixels,
            ``-1`` for outlier pixels and ``VOID_LABEL`` for void pixels.
        """
        file, mask_path = self.all_images[index], self.all_masks[index]

        # to return a PIL Image
        img = Image.open(file)

        # Context manager closes the mask file as soon as it is converted.
        with Image.open(mask_path) as mask:
            target = to_tensor(mask).squeeze(0)

        # NOTE(review): to_tensor presumably scales 8-bit masks to [0, 1],
        # hence the * 255 -- confirm. Round before the truncating cast so
        # float error (e.g. 0.99999994) cannot turn a label into 0.
        target = (target * 255).round().long()

        # void pixels to a temporary marker, so they survive the next step
        target[target == 255] = -10
        # all remaining values above 0 are outliers -> negative label
        target[target > 0] = -1
        # replace temporary marker with the final void label
        target[target == -10] = self.VOID_LABEL

        if self.transform is not None:
            img, target = self.transform(img, target)

        return img, target