"""
"""
import torch
import torch.nn as nn
from ..loss.crossentropy import cross_entropy
from ..utils import apply_reduction, is_known, is_unknown
def _energy(logits: torch.Tensor) -> torch.Tensor:
    """
    Computes the energy :math:`E(x) = -\\log \\sum_i e^{f_i(x)}` along dimension 1.

    :param logits: tensor whose dimension 1 holds the class scores
    :return: ``logits`` reduced over dimension 1 to the negated log-sum-exp
    """
    log_partition = torch.logsumexp(logits, dim=1)
    return torch.neg(log_partition)
class EnergyRegularizedLoss(nn.Module):
    """
    Augments the cross-entropy by a regularization term
    that aims to increase the energy gap between ID and OOD samples.
    This term is defined as

    .. math::
        \\mathcal{L}(x, y) = \\alpha
        \\Biggl \\lbrace
        {
        \\max(0, E(x) - m_{in})^2 \\quad \\quad \\quad \\quad \\quad \\quad \\text{if } y \\geq 0
        \\atop
        \\max(0, m_{out} - E(x))^2 \\quad \\quad \\quad \\quad \\quad \\text{ otherwise }
        }

    where :math:`E(x) = - \\log(\\sum_i e^{f_i(x)} )` is the energy of :math:`x`.
    The total loss is the cross-entropy plus this term.

    :see Paper:
        `NeurIPS <https://proceedings.neurips.cc/paper/2020/file/f5496252609c43eb8a3d147ab9b9c006-Paper.pdf>`__

    :see Implementation: `GitHub <https://github.com/wetliu/energy_ood>`__
    """

    def __init__(
        self,
        alpha: float = 1.0,
        margin_in: float = -1.0,
        margin_out: float = -1.0,
        reduction: str = "mean",
    ):
        """
        :param alpha: weighting parameter of the energy regularization term
        :param margin_in: margin energy :math:`m_{in}` for ID data
        :param margin_out: margin energy :math:`m_{out}` for OOD data
        :param reduction: can be one of ``none``, ``mean``, ``sum``
        """
        super().__init__()
        self.m_in = margin_in
        self.m_out = margin_out
        self.alpha = alpha
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Calculates weighted sum of cross-entropy and the energy regularization term.

        :param logits: logits, either :math:`B \\times C` (classification) or
            :math:`B \\times C \\times H \\times W` (segmentation)
        :param targets: labels; ID/OOD membership is decided by ``is_known`` / ``is_unknown``
            (per the formula, presumably :math:`y \\geq 0` is ID -- confirm against those helpers)
        :return: ``cross_entropy + alpha * regularization``, reduced according to ``reduction``
        :raises ValueError: if ``logits`` is neither 2- nor 4-dimensional
        """
        regularization = self._regularization(logits, targets)
        # Unreduced cross-entropy, so the reduction is applied once to the combined loss.
        nll = cross_entropy(logits, targets, reduction="none")
        # NOTE(review): for segmentation the regularization is a scalar that is broadcast
        # over every element of ``nll``; with reduction="sum" it is therefore scaled by the
        # number of elements in ``nll`` -- confirm this is intended.
        return apply_reduction(nll + self.alpha * regularization, reduction=self.reduction)

    def _regularization(self, logits: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Computes the squared-hinge energy margin penalty.

        :param logits: logits of shape :math:`B \\times C` or :math:`B \\times C \\times H \\times W`
        :param y: labels used to split samples (or pixels) into ID and OOD
        :return: for classification, a per-sample penalty of shape :math:`B`; for
            segmentation, the mean penalty over ID pixels plus the mean penalty over OOD pixels
            (each term is 0 when there are no pixels of that kind)
        :raises ValueError: for any other number of logit dimensions
        """
        # classification
        if logits.dim() == 2:
            known = is_known(y)
            energy = _energy(logits)
            # ID samples are penalized for energy above m_in, OOD samples for energy below m_out.
            penalty_in = (energy - self.m_in).relu().pow(2)
            penalty_out = (self.m_out - energy).relu().pow(2)
            # A single mask selects both branches, so ID and OOD can never disagree in shape.
            return torch.where(known, penalty_in, penalty_out)

        # segmentation
        if logits.dim() == 4:
            # Move the class axis last so that boolean masks over the leading dims select
            # per-pixel logit vectors (assumes y is B x H x W -- confirm with callers).
            per_pixel = logits.permute(0, 2, 3, 1)
            known = is_known(y)
            unknown = is_unknown(y)

            energy_in = 0
            if known.any():
                energy_in = (_energy(per_pixel[known]) - self.m_in).relu().pow(2).mean()

            energy_out = 0
            if unknown.any():
                # The margin is subtracted from the energy, not from the raw logits.
                energy_out = (self.m_out - _energy(per_pixel[unknown])).relu().pow(2).mean()

            return energy_in + energy_out

        raise ValueError(f"Unsupported input shape: {logits.shape}")