from typing import Dict, List, Optional

from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import MNIST, SVHN, ImageNet
from torchvision.transforms import Compose

from pytorch_ood.api import Detector
from pytorch_ood.benchmark import Benchmark
from pytorch_ood.dataset.img import ImageNetO, OpenImagesO, Textures
from pytorch_ood.utils import OODMetrics, ToRGB, ToUnknown
# ImageNet-based OOD detection benchmark (OpenOOD setup).
class ImageNet_OpenOOD(Benchmark):
    """
    Aims to replicate the ImageNet benchmark proposed in
    *OpenOOD: Benchmarking Generalized Out-of-Distribution Detection*.

    :see Paper: `OpenOOD <https://openreview.net/pdf?id=gT6j4_tskUt>`__

    Outlier datasets are

     * ImageNet-O
     * OpenImage-O
     * Textures
     * SVHN
     * MNIST

    .. warning :: This currently does not reproduce the benchmark accurately, as it does not exclude images with
        overlap with ImageNet and is missing the Species dataset.
    """

    def __init__(self, root, image_net_root, transform) -> None:
        """
        :param root: where to store datasets
        :param image_net_root: root for the ImageNet dataset
        :param transform: transform to apply to images
        """
        # ToRGB() runs before the user transform on every dataset
        self.transform = Compose([ToRGB(), transform])
        # the ImageNet training split is only created on first access of ``train_in``
        self._train_in = None
        self.image_net_root = image_net_root
        self.test_in = ImageNet(image_net_root, transform=self.transform, split="val")

        # all outlier datasets share the same image transform and pass their
        # targets through ToUnknown()
        self.test_oods = [
            ImageNetO(
                root,
                download=True,
                transform=self.transform,
                target_transform=ToUnknown(),
            ),
            OpenImagesO(
                root,
                download=True,
                transform=self.transform,
                target_transform=ToUnknown(),
            ),
            Textures(
                root,
                download=True,
                transform=self.transform,
                target_transform=ToUnknown(),
            ),
            SVHN(
                root,
                split="test",
                download=True,
                transform=self.transform,
                target_transform=ToUnknown(),
            ),
            MNIST(
                root,
                # select the test split explicitly; the second positional
                # argument of MNIST is ``train``, so it must not receive a path
                train=False,
                download=True,
                transform=self.transform,
                target_transform=ToUnknown(),
            ),
        ]

        #: OOD Dataset names (class names of the outlier datasets, same order as ``test_oods``)
        self.ood_names: List[str] = [type(d).__name__ for d in self.test_oods]

    @property
    def train_in(self):
        """
        ImageNet training split, created on first access and cached afterwards.
        """
        # lazy loading only if needed; compare against None so that the cached
        # dataset is never re-created based on its truthiness
        if self._train_in is None:
            self._train_in = ImageNet(
                self.image_net_root, split="train", transform=self.transform
            )
        return self._train_in

    def train_set(self) -> Dataset:
        """
        Training dataset
        """
        return self.train_in

    def test_sets(self, known=True, unknown=True) -> List[Dataset]:
        """
        List of the different test datasets.
        If known and unknown are true, each dataset contains IN and OOD data.

        :param known: include IN
        :param unknown: include OOD
        :return: one combined dataset per outlier dataset if both flags are set,
            a single-element list with the IN test data if only ``known`` is set,
            or the list of outlier datasets if only ``unknown`` is set
        :raise ValueError: if neither ``known`` nor ``unknown`` is set
        """
        if known and unknown:
            # pair the IN test data with each outlier dataset via ``+``
            return [self.test_in + other for other in self.test_oods]

        if known:
            # IN *test* data only; the training split is available via train_set()
            return [self.test_in]

        if unknown:
            return self.test_oods

        raise ValueError("At least one of 'known' and 'unknown' must be true")

    def evaluate(
        self, detector: Detector, loader_kwargs: Optional[Dict] = None, device: str = "cpu"
    ) -> List[Dict]:
        """
        Evaluates the given detector on all datasets and returns a list with the results

        :param detector: the detector to evaluate
        :param loader_kwargs: keyword arguments to give to the data loader
        :param device: the device to move batches to
        :return: one result dictionary per outlier dataset, each holding the computed
            metrics and the dataset name under the key ``Dataset``
        """
        if loader_kwargs is None:
            loader_kwargs = {}

        metrics = []

        # test_sets() yields one IN+OOD dataset per entry of ood_names, in the same order
        for name, dataset in zip(self.ood_names, self.test_sets()):
            print(name)
            loader = DataLoader(dataset=dataset, **loader_kwargs)

            # fresh metric accumulator for every outlier dataset
            m = OODMetrics()

            for x, y in loader:
                m.update(detector(x.to(device)), y)

            r = m.compute()
            r.update({"Dataset": name})

            metrics.append(r)

        return metrics