Source code for pytorch_ood.augment.img.pixmix

"""
Much of this code is directly taken or adapted from https://github.com/andyzoujm/pixmix/
which is licensed under MIT according to the GitHub repo; however, some of the files have an
Apache 2.0 license header. Both should be compatible with our license.
"""

import logging
import os
from os.path import join
from typing import Any, Callable, Optional, Tuple

import numpy as np
import torch
import torchvision.transforms
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.datasets import VisionDataset
from torchvision.datasets.utils import check_integrity, extract_archive
from torchvision.transforms.functional import to_tensor

log = logging.getLogger(__name__)


# NOTE: list of augmentations used originally
# augmentations_all = [
#     F.autocontrast, F.equalize, F.posterize, F.rotate, F.solarize, F.shear_x, F.shear_y,
#     F.translate_x, F.translate_y, F.color, F.contrast, F.brightness, F.sharpness
# ]


def get_ab(beta):
    """
    Draw a random pair of mixing coefficients ``(a, b)``.

    A fair coin decides between two regimes:

    * both coefficients are Beta samples, so each lies in ``[0, 1]``
    * ``a`` is one plus a Beta sample and ``b`` is a negated Beta sample,
      so ``a`` lies in ``[1, 2]`` and ``b`` in ``[-1, 0]``

    :param beta: shape parameter handed to ``np.random.beta``
    :return: tuple ``(a, b)``
    """
    use_plain_betas = np.random.random() < 0.5

    if not use_plain_betas:
        # the two draws below must stay in this order to keep the RNG stream intact
        coeff_a = 1 + np.float32(np.random.beta(1, beta))
        coeff_b = -np.float32(np.random.beta(1, beta))
        return coeff_a, coeff_b

    coeff_a = np.float32(np.random.beta(beta, 1))
    coeff_b = np.float32(np.random.beta(1, beta))
    return coeff_a, coeff_b


def add(img1, img2, beta):
    """
    Additive mixing: a random linear combination of two images.

    Both inputs are rescaled with ``x * 2 - 1`` (which maps ``[0, 1]`` onto
    ``[-1, 1]``), combined with the coefficients drawn by :func:`get_ab`, and
    the result is rescaled back with ``(x + 1) / 2``. The output is not
    clipped here.

    :param img1: first image
    :param img2: second image
    :param beta: mixing coefficient, forwarded to :func:`get_ab`
    :return: the mixed image
    """
    weight_1, weight_2 = get_ab(beta)

    centered_1 = img1 * 2 - 1
    centered_2 = img2 * 2 - 1
    blended = weight_1 * centered_1 + weight_2 * centered_2

    return (blended + 1) / 2


def multiply(img1, img2, beta):
    """
    Multiplicative mixing: a random geometric combination of two images.

    Both inputs are doubled, raised to the exponents drawn by :func:`get_ab`,
    multiplied together, and the product is halved. The output is not
    clipped here.

    :param img1: first image
    :param img2: second image
    :param beta: mixing coefficient, forwarded to :func:`get_ab`
    :return: the mixed image
    """
    exponent_1, exponent_2 = get_ab(beta)

    base_1 = img1 * 2
    # the second exponent can be negative, so keep its base strictly positive
    base_2 = (img2 * 2).clip(1e-37)

    product = (base_1**exponent_1) * (base_2**exponent_2)
    return product / 2


class PixMixDataset(torch.utils.data.Dataset):
    """
    Dataset wrapper to perform PixMix, from the paper
    *PixMix: Dreamlike Pictures Comprehensively Improve Safety Measures*.

    :see Paper: `ArXiv <https://arxiv.org/abs/2112.05135>`__

    .. note :: Some of the augmentations primitives used in the paper are not yet implemented.

    .. image:: https://github.com/andyzoujm/pixmix/raw/main/assets/pixmix.png
        :width: 800px
        :alt: Pixmix sketch
        :align: center

    """

    def __init__(
        self,
        dataset,
        mixing_set,
        beta=3,
        aug_severity=3,
        k=4,
        std=(1.0, 1.0, 1.0),
        mean=(1.0, 1.0, 1.0),
    ):
        """
        :param dataset: original dataset; items are unpacked as ``(image, label)``
        :param mixing_set: dataset used for mixing; items are unpacked as ``(image, _)``
        :param beta: mixing coefficient
        :param aug_severity: severity used for augmentation primitives
        :param k: number of mixing iterations
        :param mean: used for normalization
        :param std: used for normalization

        NOTE(review): the default ``mean`` of 1.0 (with ``std`` of 1.0) shifts an
        image with values in ``[0, 1]`` to ``[-1, 0]`` -- confirm this is intended.
        """
        self.dataset = dataset
        self.mixing_set = mixing_set
        self.normalize = torchvision.transforms.Normalize(std=std, mean=mean)

        self.mixings = [add, multiply]

        # Every primitive is called as ``op(image, severity)``. Named static
        # methods are used instead of lambdas so that the dataset can be
        # pickled, which multi-worker data loading requires when worker
        # processes are spawned rather than forked.
        self.augmentations = [
            self._equalize,
            self._autocontrast,
            F.posterize,
            F.rotate,
            F.solarize,
        ]

        self.beta = beta
        self.aug_severity = aug_severity
        self.k = k

    @staticmethod
    def _equalize(image, severity):
        """Equalize ``image``; ``severity`` is accepted for a uniform signature and ignored."""
        return F.equalize(image)

    @staticmethod
    def _autocontrast(image, severity):
        """Autocontrast ``image``; ``severity`` is accepted for a uniform signature and ignored."""
        return F.autocontrast(image)

    def __getitem__(self, i):
        """
        :param i: index into the original dataset
        :return: tuple of the PixMix-augmented, normalized image and the original label
        """
        x, y = self.dataset[i]
        # the mixing picture is drawn uniformly at random, independent of ``i``
        rnd_idx = np.random.choice(len(self.mixing_set))
        mixing_pic, _ = self.mixing_set[rnd_idx]
        return self._pixmix(x, mixing_pic), y

    def __len__(self):
        """
        :return: number of items in the original dataset
        """
        return len(self.dataset)

    def _pixmix(self, orig, mixing_pic):
        """
        :param orig: original image (must provide ``copy()`` and be accepted by
            ``to_tensor``; presumably a PIL image -- confirm against callers)
        :param mixing_pic: picture to mix in
        :return: the mixed image, clipped to ``[0, 1]`` after each mixing round
            and then normalized
        """
        # first, apply one of the augmentations with 50% chance to the original image
        # TODO: make probability configurable?
        if np.random.random() < 0.5:
            mixed = to_tensor(self.augment_input(orig, severity=self.aug_severity))
        else:
            mixed = to_tensor(orig)

        # then up to k times: create an augmented copy of the original or use the
        # mixing pic, and mix whatever you generated into the current image
        for _ in range(np.random.randint(self.k + 1)):
            if np.random.random() < 0.5:
                aug_image_copy = to_tensor(self.augment_input(orig, severity=self.aug_severity))
            else:
                aug_image_copy = to_tensor(mixing_pic)

            # mix current image and augmented copy
            mixed_op = np.random.choice(self.mixings)
            mixed = mixed_op(mixed, aug_image_copy, self.beta)
            # clip after every round: ``add`` can leave [0, 1], and a later
            # ``multiply`` raises the current image to a non-integer power
            mixed = torch.clip(mixed, 0, 1)

        return self.normalize(mixed)

    def augment_input(self, image, severity):
        """
        Apply one randomly chosen augmentation primitive to a copy of the image.

        :param image: input image; it is copied before the primitive is applied
        :param severity: passed as second positional argument to the primitive
        :return: the augmented copy
        """
        op = np.random.choice(self.augmentations)
        return op(image.copy(), severity)