Source code for pytorch_ood.utils.transforms

"""

..  autoclass:: pytorch_ood.utils.ToUnknown
    :members:

..  autoclass:: pytorch_ood.utils.ToRGB
    :members:

..  autoclass:: pytorch_ood.utils.TargetMapping
    :members:

"""

from typing import Set, Callable, Union

import os
import random
from os.path import join

import numpy as np
from PIL import Image, ImageDraw
import torch
from collections import defaultdict
import json
from typing import List, Tuple

from torch import Tensor
from torchvision.datasets.utils import download_and_extract_archive


class ToUnknown:
    """
    Target transform that discards the incoming label and always yields ``-1``.

    Negative labels denote OOD / unknown samples, so plugging this callable
    into a pipeline marks every sample of a dataset as OOD or unknown.
    """

    def __init__(self):
        """Create the transform; it takes no configuration."""

    def __call__(self, y):
        """
        :param y: original label, ignored
        :return: always ``-1``
        """
        unknown_label = -1
        return unknown_label
class ToRGB:
    """
    Transform that converts an image to RGB mode on a best-effort basis.
    """

    def __call__(self, x):
        """
        :param x: object exposing a ``convert(mode)`` method
        :return: result of ``x.convert("RGB")``, or ``x`` itself if that call
            raises any exception
        """
        try:
            rgb = x.convert("RGB")
        except Exception:
            # Best effort: anything that cannot be converted is passed through
            # unchanged instead of aborting the pipeline.
            return x
        return rgb
class TargetMapping(object):
    """
    Maps ID (a.k.a. known) classes to labels :math:`\\in [0,n]`, and OOD (a.k.a. unknown) classes to labels in
    :math:`[-\\infty, -1]`. This is required for open set simulations.

    **Example:**
    If we split up a dataset so that the classes 2,3,4,9 are considered *known* or *ID*, these class
    labels have to be remapped to 0,1,2,3 to be able to train using cross entropy with 1-of-K-vectors.
    All other classes have to be mapped to values :math:`<0` to be marked as OOD.

    Known classes receive their index in the iteration order of ``set(known)``, which is
    not necessarily the sorted order of the class labels.

    Unknown classes with a positive label ``c`` are mapped to ``-c``. Unknown classes whose
    negation would not be negative (e.g. class ``0``) receive a distinct negative label
    below all labels assigned that way, so that every unknown class ends up :math:`<0`.
    """

    def __init__(self, known: Set, unknown: Set):
        """
        :param known: class labels considered ID; mapped to ``0 .. len(known) - 1``
        :param unknown: class labels considered OOD; mapped to values ``< 0``. If a class is
            contained in both sets, the OOD mapping takes precedence.
        """
        self._map = dict()
        self._map.update({clazz: index for index, clazz in enumerate(set(known))})

        # mapping train_out classes to < 0
        unknown_classes = set(unknown)
        ood_labels = {clazz: -clazz for clazz in unknown_classes if clazz > 0}

        # ``-clazz`` is not negative for classes <= 0 (in particular ``-0 == 0`` would collide
        # with ID label 0), so these classes get fresh negative labels that are not in use yet.
        next_free = min(ood_labels.values(), default=0) - 1
        for clazz in sorted(unknown_classes.difference(ood_labels)):
            ood_labels[clazz] = next_free
            next_free -= 1

        self._map.update(ood_labels)

    @staticmethod
    def _as_key(value):
        """Unwrap a tensor into a python scalar so that it can be used as a dictionary key."""
        if isinstance(value, torch.Tensor):
            return value.item()
        return value

    def __call__(self, target):
        """
        :param target: class label, either a python scalar or a tensor holding one
        :return: the mapped label, or ``-1`` if the class is neither known nor unknown
        """
        return self._map.get(self._as_key(target), -1)

    def __getitem__(self, item):
        """
        :param item: class label, either a python scalar or a tensor holding one
        :return: the mapped label
        :raise KeyError: if the class is neither known nor unknown
        """
        return self._map[self._as_key(item)]

    def items(self):
        """
        :return: view of all ``(class, mapped label)`` pairs
        """
        return self._map.items()

    def __repr__(self):
        return str(self._map)