from abc import ABC, abstractmethod
from typing import TypeVar
from torch import Tensor
from torch.utils.data import DataLoader
# Type variable used to annotate methods whose return type is the type of
# the instance they are called on (see ``Detector.fit`` / ``fit_features``).
Self = TypeVar("Self")
class RequiresFittingException(Exception):
    """
    Raised when predict is called on a detector that has not been fitted.

    :param msg: error message; defaults to a hint to call ``fit()`` first.
    """

    def __init__(self, msg: str = "You have to call fit() before predict()") -> None:
        super().__init__(msg)
class ModelNotSetException(ValueError):
    """
    Raised when predict() (or fit()) is called but no model was given.

    Derives from :class:`ValueError`, so existing ``except ValueError``
    handlers also catch it.

    :param msg: error message; defaults to a hint that the model must not be
        ``None`` when using ``predict()``.
    """

    def __init__(self, msg: str = "When using predict(), model must not be None") -> None:
        super().__init__(msg)
class Detector(ABC):
    """
    Abstract Base Class for an Out-of-Distribution Detector.

    Subclasses have to implement :meth:`fit`, :meth:`fit_features`,
    :meth:`predict` and :meth:`predict_features`. Calling an instance
    forwards to :meth:`predict`.
    """

    def __call__(self, *args, **kwargs) -> Tensor:
        """
        Forwards to :meth:`predict`.

        :param args: positional arguments passed on to :meth:`predict`
        :param kwargs: keyword arguments passed on to :meth:`predict`
        :return: the value returned by :meth:`predict`
        """
        return self.predict(*args, **kwargs)

    @abstractmethod
    def fit(self: Self, data_loader: DataLoader) -> Self:
        """
        Fit the detector to a dataset. Some methods require this.

        :param data_loader: dataset to fit on. This is usually the training dataset.
        :return: the detector; the signature ties the return type to the type of ``self``
        :raise ModelNotSetException: if model was not set
        """
        raise NotImplementedError

    @abstractmethod
    def fit_features(self: Self, x: Tensor, y: Tensor) -> Self:
        """
        Fit the detector directly on features. Some methods require this.

        :param x: training features to use for fitting.
        :param y: corresponding class labels.
        :return: the detector; the signature ties the return type to the type of ``self``
        """
        raise NotImplementedError

    @abstractmethod
    def predict(self, x: Tensor) -> Tensor:
        """
        Calculates outlier scores. Inputs will be passed through the model.

        :param x: batch of data
        :return: outlier scores for points
        :raise RequiresFittingException: if detector has to be fitted to some data
        :raise ModelNotSetException: if model was not set
        """
        raise NotImplementedError

    @abstractmethod
    def predict_features(self, x: Tensor) -> Tensor:
        """
        Calculates outlier scores based on features.

        :param x: batch of data
        :return: outlier scores for points
        :raise RequiresFittingException: if detector has to be fitted to some data
        """
        raise NotImplementedError