"""
.. image:: https://img.shields.io/badge/classification-yes-brightgreen?style=flat-square
:alt: classification badge
.. image:: https://img.shields.io/badge/segmentation-yes-brightgreen?style=flat-square
:alt: segmentation badge
.. autoclass:: pytorch_ood.detector.ASH
:members:
:inherited-members:
:show-inheritance:
:exclude-members: fit
"""
import logging
from typing import Callable, TypeVar
import numpy as np
import torch.nn
from torch import Tensor
from ..api import FeatureMapsDetector
from .energy import EnergyBased
log = logging.getLogger(__name__)
Self = TypeVar("Self")
def ash_b(x: Tensor, percentile: float = 0.65) -> Tensor:
assert x.dim() == 4
b, c, h, w = x.shape
# calculate the sum of the input per sample
s1 = x.sum(dim=[1, 2, 3])
n = x.shape[1:].numel()
k = n - int(np.round(n * percentile))
t = x.view((b, c * h * w))
v, i = torch.topk(t, k, dim=1)
fill = s1 / k
fill = fill.unsqueeze(dim=1).expand(v.shape)
t.zero_().scatter_(dim=1, index=i, src=fill)
return x
def ash_p(x: Tensor, percentile: float = 0.65) -> Tensor:
assert x.dim() == 4
b, c, h, w = x.shape
n = x.shape[1:].numel()
k = n - int(np.round(n * percentile))
t = x.view((b, c * h * w))
v, i = torch.topk(t, k, dim=1)
t.zero_().scatter_(dim=1, index=i, src=v)
return x
def ash_s(x: Tensor, percentile: float = 0.65) -> Tensor:
assert x.dim() == 4
b, c, h, w = x.shape
# calculate the sum of the input per sample
s1 = x.sum(dim=[1, 2, 3])
n = x.shape[1:].numel()
k = n - int(np.round(n * percentile))
t = x.view((b, c * h * w))
v, i = torch.topk(t, k, dim=1)
t.zero_().scatter_(dim=1, index=i, src=v)
# calculate new sum of the input per sample after pruning
s2 = x.sum(dim=[1, 2, 3])
# apply sharpening
scale = s1 / s2
x = x * torch.exp(scale[:, None, None, None])
return x
[docs]
class ASH(FeatureMapsDetector):
"""
Implements ASH from the paper
*Extremely Simple Activation Shaping for Out-of-Distribution Detection*.
ASH prunes the activations in some layer of the network (backbone) by removing a certain percentile of
the highest activations. The remaining activations are modified, depending on the particular variant selected, and
propagated through the remainder (head) of the network.
Then uses the energy based outlier score.
This approach has been shown to increase OOD detection rates while maintaining ID accuracy.
* ASH-P: only prune, do not modify
* ASH-B: binarize remaining activations
* ASH-S: rescale remaining activations
The paper applies ASH after the last average pooling layer.
Example Code:
.. code :: python
model = WideResNet()
detector = ASH(
backbone = model.features_before_pool,
head = model.forward_from_before_pool,
detector=EnergyBased.score
)
scores = detector(images)
:see Paper: `ICLR 2023 <https://openreview.net/pdf?id=ndYXTEL6cZz>`__
:see Website: `github.io <https://andrijazz.github.io/ash/>`__
"""
variants = {
"ash-s": ash_s,
"ash-p": ash_p,
"ash-b": ash_b,
}
def __init__(
self,
backbone: Callable[[Tensor], Tensor],
head: Callable[[Tensor], Tensor],
variant="ash-s",
percentile: float = 0.65,
detector: Callable[[Tensor], Tensor] = None,
):
"""
:param variant: one of ``ash-p``, ``ash-b``, ``ash-s``
:param backbone: first part of model to use, should output feature maps
:param head: second part of model used after applying ash, should output logits
:param percentile: amount of activations to modify
:param detector: detector that maps model outputs to outlier scores. Default is Energy based.
"""
assert variant in self.variants
self.backbone = backbone
self.head = head
self.percentile = percentile
self.ash: Callable[[Tensor, float], Tensor] = self.variants[variant]
self.detector = detector or EnergyBased.score
[docs]
def predict(self, x: Tensor) -> Tensor:
"""
:param x: input, will be passed through network
"""
x = self.backbone(x)
return self.predict_feature_maps(x)
def predict_feature_maps(self, x: Tensor) -> Tensor:
x = self.ash(x, self.percentile)
x = self.head(x)
return self.detector(x)