import logging
from typing import Optional
import torch
from torch import nn
from ..utils import apply_reduction, contains_unknown, is_unknown
from .crossentropy import cross_entropy
log = logging.getLogger(__name__)
[docs]
class OutlierExposureLoss(nn.Module):
"""
Loss from the paper *Deep Anomaly Detection With Outlier Exposure*.
While the formulation in the original paper is very general, this module implements the exact loss that
was used in the corresponding experiments.
The loss is defined as
.. math::
\\mathcal{L}(x, y)
=
\\Biggl \\lbrace
{
-\\log \\sigma_y(f(x)) \\quad \\quad \\quad \\quad \\quad \\quad \\quad \\quad \\quad \\quad \\text{if } y \\geq 0
\\atop
\\alpha (\\sum_{c=1}^C f(x)_c - \\log(\\sum_{c=1}^C e^{f(x)_c})) \\quad \\text{ otherwise }
}
where :math:`C` is the number of classes, :math:`\\alpha` is a hyper parameter, and :math:`\\sigma_y`
denotes the :math:`y^{th}` softmax output.
:see Paper: `ArXiv <https://arxiv.org/pdf/1812.04606v1.pdf>`__
:see Implementation: `GitHub <https://github.com/hendrycks/outlier-exposure>`__
"""
def __init__(self, alpha: float = 0.5, reduction: Optional[str] = "mean"):
"""
:param alpha: weighting coefficient :math:`\\alpha`
:param reduction: reduction method, one of ``mean``, ``sum`` or ``none``
"""
super(OutlierExposureLoss, self).__init__()
self.alpha = alpha
self.reduction = reduction
[docs]
def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
:param logits: class logits for predictions
:param target: labels for predictions
:return: loss
"""
# for classification
if len(logits.shape) == 2:
loss_oe = torch.zeros(logits.shape[0], device=logits.device)
loss_ce = cross_entropy(logits, target, reduction=None)
if contains_unknown(target):
unknown = is_unknown(target)
loss_oe[unknown] = -(
logits[unknown].mean(dim=1) - torch.logsumexp(logits[unknown], dim=1)
)
return apply_reduction(loss_ce + self.alpha * loss_oe, reduction=self.reduction)
# for segmentation
elif len(logits.shape) == 4:
loss_ce = cross_entropy(logits, target, reduction=None)
# move class axis to the back
logits = logits.permute(0, 2, 3, 1)
if contains_unknown(target):
unknown = is_unknown(target)
# mean over class axis
loss_oe = -(logits.mean(dim=-1) - torch.logsumexp(logits, dim=-1))
fp32zero = torch.zeros((1,), dtype=torch.float, device=logits.device)
loss_oe = torch.where(target.float() < 0, loss_oe, fp32zero)
else:
loss_oe = torch.zeros(logits.shape[:3], device=logits.device)
return apply_reduction(loss_ce + self.alpha * loss_oe, reduction=self.reduction)
else:
raise ValueError(f"Unsupported input shape: {logits.shape}")