Source code for pytorch_ood.utils.transforms

"""

..  autoclass:: pytorch_ood.utils.ToUnknown
    :members:

..  autoclass:: pytorch_ood.utils.ToRGB
    :members:

..  autoclass:: pytorch_ood.utils.TargetMapping
    :members:
"""
from typing import Set

import torch


class ToUnknown(object):
    """
    Callable that returns a negative number, used in pipelines to mark
    specific datasets as OOD or unknown.

    The argument passed to the call is ignored; the result is always ``-1``.
    """

    def __init__(self):
        """Create the transform; it holds no state."""

    def __call__(self, y):
        """
        :param y: original target; not inspected
        :return: always ``-1``
        """
        return -1
class ToRGB(object):
    """
    Convert Image to RGB, if it is not already.

    The conversion is best effort: the input is handed back untouched
    whenever calling ``convert("RGB")`` on it fails.
    """

    def __call__(self, x):
        """
        :param x: object to convert; presumably a PIL image, but anything
            exposing ``convert(mode)`` works
        :return: result of ``x.convert("RGB")``, or ``x`` itself if that
            call raised an exception
        """
        try:
            converted = x.convert("RGB")
        except Exception:
            # Deliberate fallback: inputs that cannot be converted
            # (e.g. objects without a ``convert`` method) pass through as-is.
            return x
        return converted
[docs] class TargetMapping(object): """ Maps known classes to index in :math:`[0,n]`, unknown classes to values in :math:`[-\\infty, -1]`. Required for open set simulations. **Example:** If we split up a dataset so that the classes 2,3,4,9 are considered *known* or *IN*, these class labels have to be remapped to 0,1,2,3 to be able to train using cross entropy with 1-of-K-vectors. All other classes have to be mapped to values :math:`<0` to be marked as OOD. Target mappings have to be known at evaluation time. """ def __init__(self, known: Set, unknown: Set): self._map = dict() self._map.update({clazz: index for index, clazz in enumerate(set(known))}) # mapping train_out classes to < 0 self._map.update({clazz: (-clazz) for index, clazz in enumerate(set(unknown))}) def __call__(self, target): if isinstance(target, torch.Tensor): return self._map.get(target.item(), -1) return self._map.get(target, -1) def __getitem__(self, item): if isinstance(item, torch.Tensor): return self._map[item.item()] return self._map[item] def items(self): return self._map.items() def __repr__(self): return str(self._map)