Source code for pytorch_ood.benchmark.base

from abc import ABC, abstractmethod
from typing import Dict, List

from torch.utils.data import Dataset

from pytorch_ood.api import Detector


class Benchmark(ABC):
    """
    Base class for Benchmarks.

    Declares the interface every benchmark has to provide: a training
    dataset, a list of test datasets, and an evaluation routine for a
    detector. All three methods are abstract, so this class cannot be
    instantiated directly; subclasses must override each of them.
    """

    @abstractmethod
    def train_set(self) -> Dataset:
        """
        Training dataset.

        :return: the dataset used for training
        """

    @abstractmethod
    def test_sets(self, known: bool = True, unknown: bool = True) -> List[Dataset]:
        """
        List of the different test datasets. If ``known`` and ``unknown``
        are both true, each dataset contains ID and OOD data.

        :param known: include ID
        :param unknown: include OOD
        :return: list of test datasets
        """

    @abstractmethod
    def evaluate(self, detector: Detector, *args, **kwargs) -> List[Dict]:
        """
        Evaluates the given detector on all datasets and returns a list
        with the results.

        :param detector: the detector to evaluate
        :param args: additional positional arguments; their meaning is
            defined by the implementing subclass
        :param kwargs: additional keyword arguments; their meaning is
            defined by the implementing subclass
        :return: list of dictionaries holding the results
        """