GradNormKL

Running GradNormKL on CIFAR-10.

This method computes the \(\ell_1\)-norm of the gradient of the KL divergence between the softmax output and a uniform distribution, with respect to the weights of the final classification head. In-distribution inputs tend to produce more peaked predictions (larger divergence from uniform), yielding higher gradient norms. The score is negated so that higher values indicate OOD.
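The scoring rule described above can be sketched in closed form for a single input. This is a minimal illustration, not the library's implementation. It rests on several assumptions: the divergence is taken as KL(u || softmax(z)) with u uniform, the head is a plain linear layer z = W h + b, only the weight matrix W is considered (the bias is ignored), and no temperature scaling is applied. Under those assumptions the gradient with respect to W is the outer product (p - u) hᵀ, so its \(\ell_1\)-norm factorizes into ‖p - u‖₁ · ‖h‖₁. The function names `softmax` and `gradnorm_kl_score` are hypothetical, and plain Python is used so the sketch runs without PyTorch.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gradnorm_kl_score(logits, features):
    """Negated L1 norm of dKL(u || softmax(z))/dW for a linear head z = W h + b.

    Assumptions (not taken from the library): uniform target u, weights only,
    no temperature. The gradient w.r.t. the logits is p - u, so the gradient
    w.r.t. W is the outer product (p - u) h^T, whose L1 norm is a product.
    """
    k = len(logits)
    p = softmax(logits)
    residual_l1 = sum(abs(p_j - 1.0 / k) for p_j in p)  # ||p - u||_1
    feature_l1 = sum(abs(h_i) for h_i in features)      # ||h||_1
    return -(residual_l1 * feature_l1)                  # negate: higher = more OOD


features = [0.5, 1.0, 0.25]  # hypothetical penultimate-layer activations
peaked = gradnorm_kl_score([8.0, 0.0, 0.0, 0.0], features)  # confident prediction
flat = gradnorm_kl_score([0.0, 0.0, 0.0, 0.0], features)    # uniform prediction
print(peaked, flat)
```

The peaked prediction receives the lower (more negative) score, while the perfectly uniform prediction has zero gradient norm and therefore the highest possible score, matching the convention that higher values indicate OOD.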

Unlike GradNorm, no labeled OOD data or additional classifier is required.

import logging

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

from pytorch_ood.dataset.img import Textures
from pytorch_ood.detector import GradNormKL
from pytorch_ood.model import WideResNet
from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed

logging.basicConfig(level=logging.INFO)

fix_random_seed(123)

device = "cuda"

Set up preprocessing and data

trans = WideResNet.transform_for("cifar10-pt")

dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)
dataset_out_test = Textures(
    root="data", download=True, transform=trans, target_transform=ToUnknown()
)

test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=64, num_workers=4)

Stage 1: Load pre-trained WideResNet for CIFAR-10. Disable gradients for the backbone; only the final FC layer needs them.

model = WideResNet(num_classes=10, pretrained="cifar10-pt").to(device).eval()

model.requires_grad_(False)
model.fc.requires_grad_(True)
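The effect of this freeze-then-unfreeze pattern can be checked on a small model. The sketch below uses a hypothetical `TinyNet` as a stand-in for the WideResNet (its layer sizes are arbitrary) and confirms that, after a backward pass, only the parameters of the `fc` head hold gradients.

```python
import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for a backbone followed by a final `fc` head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 3)

    def forward(self, x):
        return self.fc(torch.relu(self.backbone(x)))


tiny = TinyNet().eval()
tiny.requires_grad_(False)     # freeze every parameter
tiny.fc.requires_grad_(True)   # re-enable gradients for the head only

# Names of parameters that will receive gradients.
trainable = [name for name, p in tiny.named_parameters() if p.requires_grad]

# Backpropagate an arbitrary scalar to see where gradients actually land.
tiny(torch.randn(2, 8)).sum().backward()
with_grad = [name for name, p in tiny.named_parameters() if p.grad is not None]

print(trainable)
print(with_grad)
```

Freezing the backbone avoids computing and storing gradients the detector never reads, which matters when a backward pass is run for every test batch.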

Stage 2: Create the detector. No fitting is required.

detector = GradNormKL(model, param_filter=lambda name: name.startswith("fc"))

Stage 3: Evaluate

print("Testing...")

metrics = OODMetrics()
for x, y in test_loader:
    metrics.update(detector(x.to(device)), y)

print(metrics.compute())
