Source code for pytorch_ood.loss.vos

"""
Parts of this code are taken from
 https://github.com/deeplearning-wisc/vos/blob/a449b03c7d6e120087007f506d949569c845b2ec/classification/CIFAR/train_virtual.py

"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..loss.crossentropy import cross_entropy
from ..utils import apply_reduction, is_known, is_unknown


[docs] class VOSRegLoss(nn.Module): """ Implements the loss function from *VOS: Learning what you don’t know by virtual outlier synthesis* without the synthesising of virtual outliers. The loss adds a regularization term to the cross-entropy that aims to increase the (weighted) energy gap between ID and OOD samples. The regularization term is defined as: .. math:: \\mathcal{L} = \\mathbb{E}_{v \\sim V} \\left[ -\\text {log}\\frac{1}{1+\\text{exp}^{-\\phi(E(v))}} \\right] + \\mathbb{E}_{x \\sim D} \\left[ -\\text {log} \\frac{\\text{exp}^{-\\phi(E(x))}}{1+ \\text{exp}^{-\\phi(E(x))}}\\right] where :math:`\\phi` is a possibly non-linear function, :math:`E` is the weighted energy and :math:`V` and :math:`D` are the distributions of the (possibly virtual) outliers and the ID data respectively. :see Paper: `ArXiv <https://arxiv.org/pdf/2202.01197.pdf>`__ :see Implementation: `GitHub <https://github.com/deeplearning-wisc/vos/>`__ For initialisation of :math:`\\phi` and the weights for weighted energy: .. code :: python phi = torch.nn.Linear(1, 2) weights = torch.nn.Linear(num_classes, 1) torch.nn.init.uniform_(weights.weight) criterion = VOSRegLoss(phi, weights) .. note :: This implementation does not generate synthetic outliers. For this feature, see :class:`pytorch_ood.loss.vos.VirtualOutlierSynthesizingRegLoss`. """ def __init__( self, logistic_regression: torch.nn.Linear, weights_energy: torch.nn.Linear, alpha: float = 0.1, device: str = "cpu", reduction: str = "mean", ): """ :param logistic_regression: :math:`\\phi` function. Can be for example a linear layer. :param weights_energy: neural network layer with weights for the energy :param alpha: weighting parameter :math:`\\alpha`. :param reduction: reduction method to apply, one of ``mean``, ``sum`` or ``none`` :param device: For example ``cpu`` or ``cuda:0`` """ super(VOSRegLoss, self).__init__() self.logistic_regression = logistic_regression self.weights_energy: torch.nn.Linear = weights_energy self.alpha = alpha self.device = device self.reduction = reduction self.nll = cross_entropy
[docs] def forward(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor: """ :param logits: logits :param y: labels """ regularization = self._regularization(logits, y) loss = self.nll(logits, y, reduction=self.reduction) return apply_reduction(loss, self.reduction) + apply_reduction( self.alpha * regularization, self.reduction )
def _regularization(self, logits, y): """ :param logits: logits :param y: labels """ # Permutation depends on shape of logits if len(logits.shape) == 4: logits_form = logits.permute(0, 2, 3, 1) else: logits_form = logits energy_x_in = self._energy(logits_form[is_known(y)]) energy_v_out = self._energy(logits_form[is_unknown(y)]) return self._calculate_reg_loss(energy_x_in, energy_v_out) def _calculate_reg_loss(self, energy_score_for_fg, energy_score_for_bg): """ :param energy_score_for_fg: energy score for in-of-distribution samples :param energy_score_for_bg: energy score for out-of-distribution samples :param features: features of in-of-distribution samples :param ood_samples: out-of-distribution samples """ input_for_lr = torch.cat((energy_score_for_fg, energy_score_for_bg), -1) labels_for_lr = torch.cat( ( torch.ones(len(energy_score_for_fg)).to(self.device), torch.zeros(len(energy_score_for_bg)).to(self.device), ), -1, ) output1 = self.logistic_regression(input_for_lr.view(-1, 1)) lr_reg_loss = self.nll(output1, labels_for_lr.long()) return lr_reg_loss def _energy(self, logits, dim=1, keepdim=False): """ Numerically stable implementation of the energy calculation :param logits: logits :param dim: dimension to reduce :param keepdim: keep dimension """ m, _ = torch.max(logits, dim=dim, keepdim=True) value0 = logits - m if keepdim is False: m = m.squeeze(dim) return -( m + torch.log( torch.sum( F.relu(self.weights_energy.weight) * torch.exp(value0), dim=dim, keepdim=keepdim, ) ) )
[docs] class VirtualOutlierSynthesizingRegLoss(VOSRegLoss): """ Implements the loss function of *VOS: Learning what you don’t know by virtual outlier synthesis* with additional sampling of virtual outliers. These outliers are synthesized by fitting a gaussian to the latent features and sampling from low-likelihood regions. This alleviates the need for real outliers during training. For more information see :class:`VOS Energy-Based Loss<pytorch_ood.loss.vos.VOSRegLoss>`. :see Paper: `ArXiv <https://arxiv.org/pdf/2202.01197.pdf>`__ :see Implementation: `GitHub <https://github.com/deeplearning-wisc/vos/>`__ """ def __init__( self, logistic_regression: torch.nn.Linear, weights_energy: torch.nn.Linear, device: str, num_classes: int, num_input_last_layer: int, fc: torch.nn.Linear, alpha: float = 0.1, reduction: str = "mean", sample_number: int = 1000, select: int = 1, sample_from: int = 10000, ) -> None: """ :param logistic_regression: :math:`\\phi` function. Can be for example a linear layer. :param weights_energy: neural network layer, with weights for the energy :param device: For example ``cpu`` or ``cuda:0`` :param num_classes: number of classes :param num_input_last_layer: number of inputs in the last layer of the network :param fc: fully connected last layer of the network :param alpha: weighting parameter :param reduction: reduction method to apply, one of ``mean``, ``sum`` or ``none`` :param sample_number: number of samples that are used for virtual outlier synthesis :param select: number of highest density samples that are used for virtual outlier synthesis :param sample_from: number of samples that are used for sampling the probability distribution """ super(VirtualOutlierSynthesizingRegLoss, self).__init__( logistic_regression, weights_energy, device=device, alpha=alpha, reduction=reduction, ) self.num_classes = num_classes self.num_input_last_layer = num_input_last_layer self.fc = fc self.sample_number = sample_number self.select = select self.sample_from = sample_from self.number_dict = {} for i in range(num_classes): self.number_dict[i] = 0 self.data_dict = torch.zeros( num_classes, self.sample_number, self.num_input_last_layer ).to(self.device) self.eye_matrix = torch.eye(self.num_input_last_layer, device=self.device)
[docs] def forward(self, logits: torch.Tensor, features: torch.Tensor, y: torch.Tensor): """ :param logits: logits :param features: features :param y: labels """ # check for outlier targets (negative values) if torch.any(y < 0): raise ValueError( "Outlier targets in VirtualOutlierSynthesizingRegLoss. This loss function only supports inlier targets." ) regularization = self._regularization(logits, features, y) loss = self.nll(logits, y, reduction=self.reduction) return apply_reduction(loss, self.reduction) + apply_reduction( self.alpha * regularization, self.reduction )
def _regularization(self, prediction, features, target): """ :param prediction: logits :param features: features :param target: labels """ if len(target.shape) == 3: return self._regularization_segmentation(prediction, features, target) else: return self._regularization_classification(prediction, features, target) def _regularization_classification(self, prediction, features, target): """ :param prediction: logits :param features: features :param target: labels """ # energy regularization. sum_temp = 0 for index in range(self.num_classes): sum_temp += self.number_dict[index] lr_reg_loss = torch.zeros(1).to(self.device)[0] # case not enough samples are collected --> fill data_dict if sum_temp != self.num_classes * self.sample_number: target_numpy = target.cpu().data.numpy() for index in range(len(target)): dict_key = target_numpy[index] # get class id if self.number_dict[dict_key] < self.sample_number: self.data_dict[dict_key][self.number_dict[dict_key]] = features[index].detach() self.number_dict[dict_key] += 1 # case enough samples collected else: # update queue with new data target_numpy = target.cpu().data.numpy() for index in range(len(target)): # get class id dict_key = target_numpy[index] # update queue self.data_dict[dict_key] = torch.cat( ( self.data_dict[dict_key][1:], features[index].detach().view(1, -1), ), 0, ) # the covariance finder needs the data to be centered. for index in range(self.num_classes): if index == 0: X = self.data_dict[index] - self.data_dict[index].mean(0) mean_embed_id = self.data_dict[index].mean(0).view(1, -1) else: X = torch.cat((X, self.data_dict[index] - self.data_dict[index].mean(0)), 0) mean_embed_id = torch.cat( (mean_embed_id, self.data_dict[index].mean(0).view(1, -1)), 0 ) # add the variance. temp_precision = torch.mm(X.t(), X) / len(X) temp_precision += 0.0001 * self.eye_matrix # create distributions for each class for index in range(self.num_classes): new_dis = torch.distributions.multivariate_normal.MultivariateNormal( mean_embed_id[index], covariance_matrix=temp_precision ) negative_samples = new_dis.rsample((self.sample_from,)) prob_density = new_dis.log_prob(negative_samples) # breakpoint() # index_prob = (prob_density < - self.threshold).nonzero().view(-1) # keep the data in the low density area. cur_samples, index_prob = torch.topk(-prob_density, self.select) if index == 0: ood_samples = negative_samples[index_prob] else: ood_samples = torch.cat((ood_samples, negative_samples[index_prob]), 0) if len(ood_samples) != 0: # add some gaussian noise energy_score_for_fg = self._energy(prediction, 1) predictions_ood = self.fc(ood_samples) energy_score_for_bg = self._energy(predictions_ood, 1) lr_reg_loss = self._calculate_reg_loss(energy_score_for_fg, energy_score_for_bg) return lr_reg_loss def _regularization_segmentation(self, prediction, features, target): """ :param prediction: logits :param features: features :param target: labels """ raise NotImplementedError("Segmentation not implemented yet")