import logging
import os
from glob import glob as glb
from os.path import join
from typing import Any, Callable, Optional, Tuple
import numpy as np
import torch
from PIL import Image
from .base import ImageDatasetBase
# Module-level logger, named after this module.
# (A stray ``[docs]`` line — Sphinx "view source" residue that raised NameError on import — was removed.)
log = logging.getLogger(__name__)
class MVTechAD(ImageDatasetBase):
    """
    MVTec AD is a dataset for benchmarking anomaly detection methods with a focus on industrial inspection.
    The dataset provides segmentation masks for anomalies.

    .. image:: https://www.mvtec.com/fileadmin/Redaktion/mvtec.com/company/research/datasets/dataset_overview_large.png
        :width: 800px
        :alt: MVTec Anomaly Detection Dataset
        :align: center

    :see Paper: https://link.springer.com/content/pdf/10.1007/s11263-020-01400-4.pdf
    :see Download: https://www.mvtec.com/company/research/datasets/mvtec-ad/

    Split must be one of ``train`` or ``test``.

    Subset classes can be one of ``bottle``, ``cable``, ``capsule``, ``carpet``,
    ``grid``, ``hazelnut``, ``leather``, ``metal_nut``, ``pill``, ``screw``, ``tile``,
    ``toothbrush``, ``transistor``, ``wood`` and ``zipper``.

    Targets are pixel-wise masks of shape ``(H, W)`` and dtype ``torch.long``:
    ``0`` marks normal pixels and ``-1`` marks anomalous pixels.
    """

    splits = ["train", "test"]

    subsets = [
        "bottle",
        "cable",
        "capsule",
        "carpet",
        "grid",
        "hazelnut",
        "leather",
        "metal_nut",
        "pill",
        "screw",
        "tile",
        "toothbrush",
        "transistor",
        "wood",
        "zipper",
    ]

    url = "https://www.mydrive.ch/shares/38536/3830184030e49fe74747669442f0f282/download/420938113-1629952094/mvtec_anomaly_detection.tar.xz"
    filename = "mvtec_anomaly_detection.tar.xz"
    tgz_md5s = "4b34b33045869ee6d424616cd3a65da3"

    def __init__(
        self,
        root: str,
        split: str,
        subset: Optional[str] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        """
        :param root: root directory; the data lives in its ``mvtech-ad`` subdirectory
        :param split: split to use, one of ``train`` or ``test``
        :param subset: subset class to use; if not given, all available subsets are used
        :param transform: transformations to apply to image
        :param target_transform: transformation to apply to target masks
        :param download: set to true to automatically download the dataset
        :raises ValueError: if ``split`` or ``subset`` is not a valid value
        :raises RuntimeError: if the dataset is not found or corrupted
        """
        # NOTE(review): this deliberately calls the constructor *above* ImageDatasetBase in
        # the MRO, skipping ImageDatasetBase.__init__ — presumably because this dataset does
        # its own file discovery in load(). Confirm against .base before changing.
        super(ImageDatasetBase, self).__init__(
            join(root, "mvtech-ad"),
            transform=transform,
            target_transform=target_transform,
        )

        if split not in self.splits:
            raise ValueError(f"Invalid split: {split}")
        self.split = split

        # Validate arguments before a (potentially large) download is triggered.
        if subset and subset not in self.subsets:
            raise ValueError(f"Invalid subset: {subset}, possible values: {self.subsets}")
        self.subset = subset if subset else None

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        self.files = None
        self.labels = None
        self.load()

    def _get_subset_files(self, subset_dir: str) -> Tuple[list, list]:
        """
        Returns two lists with filenames to images and corresponding segmentation masks.
        For instances without anomalies, the segmentation mask file will be None.

        :param subset_dir: directory of a single subset class, e.g. ``<root>/bottle``
        :returns: ``(image_paths, mask_paths)`` of equal length, index-aligned
        """
        images = []
        masks = []
        split_dir = join(subset_dir, self.split)

        # Sort: neither os.listdir nor glob guarantees an ordering, and the dataset
        # index order should be identical on every machine.
        for defect_dir in sorted(os.listdir(split_dir)):
            files = sorted(glb(join(split_dir, defect_dir, "*.png")))
            if defect_dir == "good":
                labels = [None] * len(files)
            else:
                # Derive each mask path from its image name ("000.png" -> "000_mask.png")
                # instead of pairing two independent globs by position, which silently
                # misaligns every following pair if a single file is missing.
                mask_dir = join(subset_dir, "ground_truth", defect_dir)
                labels = [
                    join(mask_dir, os.path.splitext(os.path.basename(f))[0] + "_mask.png")
                    for f in files
                ]
            images += files
            masks += labels
        return images, masks

    def _get_all_files(self, root: str) -> Tuple[list, list]:
        """
        Collects images and masks of every known subset class found below ``root``.
        Subset directories that do not exist are skipped.

        :param root: dataset root containing one directory per subset class
        :returns: ``(image_paths, mask_paths)`` of equal length, index-aligned
        """
        files = []
        labels = []
        for subset in self.subsets:
            subset_dir = join(root, subset)
            if os.path.isdir(subset_dir):
                img_paths, mask_paths = self._get_subset_files(subset_dir)
                files += img_paths
                labels += mask_paths
        return files, labels

    def load(self) -> None:
        """Populates ``self.files`` and ``self.labels`` for the selected split and subset(s)."""
        if self.subset:
            self.files, self.labels = self._get_subset_files(join(self.root, self.subset))
        else:
            self.files, self.labels = self._get_all_files(self.root)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        :param index: index
        :returns: (image, target) where target is the segmentation mask, a ``torch.long``
            tensor of shape ``(H, W)`` with ``-1`` for anomalous and ``0`` for normal pixels
        """
        img_path = self.files[index]
        mask_path = self.labels[index]
        img = Image.open(img_path)

        if mask_path is None:
            # PIL's ``size`` is (width, height), whereas tensors are indexed (height, width).
            target = torch.zeros(size=(img.height, img.width), dtype=torch.long)
        else:
            with Image.open(mask_path) as mask:
                mask_array = np.array(mask)
            # Masks are stored as uint8, where ``-1 * x`` wraps around (255 -> 1) instead of
            # becoming negative; binarize and move to a signed dtype before negating.
            target = -1 * torch.from_numpy(mask_array > 0).long()

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target