# Source code for pytorch_ood.dataset.img.sumnist

import logging
import os.path
from os.path import join

import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision.datasets.utils import check_integrity, download_url

log = logging.getLogger(__name__)


class SuMNIST(Dataset):
    """
    The SuMNIST dataset comprises images with a size of :math:`56 \\times 56`, each containing 4 numbers
    from the MNIST dataset. The training set contains 60,000 normal instances, in which the numbers in
    each image sum to 20. The test set contains 10,000 images: 8,500 normal instances and 1,500
    anomalous instances, in which the numbers do not sum to 20. The challenge is to detect these
    anomalies.

    Returns a tuple with ``(img, dict)`` where dict contains bounding boxes, labels, etc.
    The target dictionary has the keys ``boxes`` (float32, one ``(x_min, y_min, x_max, y_max)`` row per
    box), ``labels`` (int64), ``image_id``, ``area``, ``iscrowd`` (all zeros) and ``anomaly``, which is
    ``0`` if the labels of the image sum to 20 and ``-1`` otherwise.

    :see Paper: `LNCS <https://link.springer.com/chapter/10.1007/978-3-031-40953-0_32>`__
    :see Examples: `GitHub <https://github.com/kkirchheim/sumnist>`__

    .. image:: https://github.com/kkirchheim/sumnist/blob/master/img/mnist-example.png?raw=true
        :width: 800px
        :alt: SuMNIST Dataset examples
        :align: center
    """

    url = "https://files.kondas.de/sumnist/"

    # file name -> expected MD5 checksum
    files = {
        "b-test.npz": "85a544301eff979e252b8946e31fd795",
        "b-train.npz": "44b6208a8675df1a78c981b5ad8c4e50",
        "x-test.npz": "7db6727ec075cca1bb4dd0881087ac57",
        "x-train.npz": "c087f1c74a6f7ffcad9956be6f99cf10",
        "y-test.npz": "7239555b3d809657c06fbbc8da6f3e5e",
        "y-train.npz": "c6c94eb5ed7ebbe1c466abb14712e807",
    }

    base_dir = "sumnist"

    def __init__(self, root, train=True, transforms=None, download=False):
        """
        :param root: where to store dataset
        :param train: set to `False` to use test set
        :param transforms: callable to apply to image and target dictionary
        :param download: set to `True` to download automatically
        :raises RuntimeError: if the dataset files are missing or fail the MD5 check
        """
        self.root = join(root, SuMNIST.base_dir)
        self.transforms = transforms

        if download:
            self.download()

        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted. You can use download=True to download it"
            )

        split = "train" if train else "test"
        # Images (x), labels (y) and bounding boxes (b) are stored in separate files per split.
        self.x = self._load_array(f"x-{split}.npz")
        self.y = self._load_array(f"y-{split}.npz")
        self.b = self._load_array(f"b-{split}.npz")

    def _load_array(self, filename):
        """Load the ``arr_0`` entry of the ``.npz`` file ``filename`` inside :attr:`root` as a tensor."""
        # torch.tensor copies the data, so it stays valid after the NpzFile is closed.
        with np.load(join(self.root, filename)) as data:
            return torch.tensor(data["arr_0"])

    def __len__(self):
        return len(self.x)

    def _check_integrity(self) -> bool:
        """Return ``True`` if every file in :attr:`files` exists and matches its MD5 checksum."""
        return all(
            check_integrity(join(self.root, file), md5) for file, md5 in SuMNIST.files.items()
        )

    def download(self):
        """Download all dataset files into :attr:`root`, unless they are already present and valid."""
        if self._check_integrity():
            log.debug("Files already downloaded and verified")
            return

        for file, md5 in SuMNIST.files.items():
            download_url(SuMNIST.url + file, self.root, md5=md5)

    def __getitem__(self, index):
        """
        :param index: index of the sample
        :return: ``(img, target)``; see the class docstring for the keys of ``target``
        """
        # To RGB: replicate the grayscale image 3 times along a leading channel dimension.
        # NOTE(review): assumes self.x[index] is (H, W) or (1, H, W) - confirm against the data files.
        img = self.x[index].repeat(3, 1, 1)

        # Stored box rows are (x_min, x_max, y_min, y_max); reorder columns to
        # (x_min, y_min, x_max, y_max). reshape keeps an empty box array well-formed.
        boxes = torch.as_tensor(self.b[index], dtype=torch.float32).reshape(-1, 4)[:, [0, 2, 1, 3]]
        labels = torch.as_tensor(self.y[index], dtype=torch.int64)
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([index]),
            "area": area,
            # Do not rebind `labels` here: the anomaly flag below must use the real labels.
            "iscrowd": torch.zeros((len(boxes),), dtype=torch.int64),
            # Normal images have labels summing to 20 (0); everything else is anomalous (-1).
            "anomaly": torch.tensor(0 if int(labels.sum()) == 20 else -1, dtype=torch.long),
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target