from abc import ABC, abstractmethod
from typing import Dict, List
from torch.utils.data import Dataset
from pytorch_ood.api import Detector
class Benchmark(ABC):
    """
    Abstract base class that defines the interface of a benchmark.

    Concrete benchmarks have to implement :meth:`train_set`,
    :meth:`test_sets` and :meth:`evaluate`; the class cannot be
    instantiated until all three are overridden.
    """

    @abstractmethod
    def train_set(self) -> Dataset:
        """
        Training dataset of the benchmark.

        :return: the dataset used for training
        """

    @abstractmethod
    def test_sets(self, known=True, unknown=True) -> List[Dataset]:
        """
        List of the different test datasets.

        If ``known`` and ``unknown`` are both true, each dataset contains
        IN and OOD data.

        :param known: include IN data
        :param unknown: include OOD data
        :return: list of test datasets
        """

    @abstractmethod
    def evaluate(self, detector: Detector, *args, **kwargs) -> List[Dict]:
        """
        Evaluates the given detector on all datasets and returns a list
        with the results.

        :param detector: the detector to evaluate
        :param args: additional positional arguments; their meaning is
            defined by the concrete benchmark
        :param kwargs: additional keyword arguments; their meaning is
            defined by the concrete benchmark
        :return: list of result dictionaries
        """