.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/detectors/pnml.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_detectors_pnml.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_detectors_pnml.py:


PNML
==============================

This example compares :class:`PNML <pytorch_ood.detector.PNML>`,
:class:`MaxSoftmax <pytorch_ood.detector.MaxSoftmax>`, and
:class:`EnergyBased <pytorch_ood.detector.EnergyBased>` on a small benchmark,
with CIFAR-10 as the in-distribution dataset and Textures as the
out-of-distribution dataset.

It follows the same structure as the other detector examples, but fits and
evaluates on small subsets (1,000 images each for fitting, in-distribution
testing, and out-of-distribution testing) so that it runs reasonably fast.

.. GENERATED FROM PYTHON SOURCE LINES 12-36

.. code-block:: Python
    :lineno-start: 13

    import logging

    import torch
    from torch.utils.data import DataLoader, Subset
    from torchvision.datasets import CIFAR10

    from pytorch_ood.dataset.img import Textures
    from pytorch_ood.detector import EnergyBased, MaxSoftmax, PNML
    from pytorch_ood.model import WideResNet
    from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed

    logging.basicConfig(level=logging.INFO)
    fix_random_seed(123)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    batch_size = 128

    # Use subsets to keep this example lightweight while still running a real benchmark.
    n_train = 1_000
    n_in_test = 1_000
    n_out_test = 1_000


.. GENERATED FROM PYTHON SOURCE LINES 37-38

Set up preprocessing and data loaders. The Textures images receive the
``ToUnknown`` target transform so that every sample is labeled as unknown
(out-of-distribution), and the test loader concatenates the CIFAR-10 and
Textures subsets into a single dataset.

.. GENERATED FROM PYTHON SOURCE LINES 38-67

.. code-block:: Python
    :lineno-start: 38

    trans = WideResNet.transform_for("cifar10-pt")

    dataset_train = CIFAR10(root="data", train=True, download=True, transform=trans)
    dataset_in_test = CIFAR10(root="data", train=False, download=True, transform=trans)
    dataset_out_test = Textures(
        root="data", download=True, transform=trans, target_transform=ToUnknown()
    )

    train_loader = DataLoader(
        Subset(dataset_train, range(n_train)),
        batch_size=batch_size,
        shuffle=True,
    )
    test_loader = DataLoader(
        Subset(dataset_in_test, range(n_in_test))
        + Subset(dataset_out_test, range(n_out_test)),
        batch_size=batch_size,
    )


    def evaluate(name, detector, loader):
        metrics = OODMetrics()
        for x, y in loader:
            metrics.update(detector(x.to(device)), y)
        print(f"{name}:")
        print(metrics.compute())


.. GENERATED FROM PYTHON SOURCE LINES 68-69

Stage 1: Create a DNN pre-trained on CIFAR-10

.. GENERATED FROM PYTHON SOURCE LINES 69-71

.. code-block:: Python
    :lineno-start: 69

    model = WideResNet(num_classes=10, pretrained="cifar10-pt").to(device).eval()


.. GENERATED FROM PYTHON SOURCE LINES 72-73

Stage 2: Create the detectors and fit them. Only PNML is fitted here, on the
training subset; MaxSoftmax and EnergyBased compute their scores directly
from the model's outputs.

.. GENERATED FROM PYTHON SOURCE LINES 73-80

.. code-block:: Python
    :lineno-start: 73

    pnml = PNML(model.features, model.fc).to(device)
    msp = MaxSoftmax(model).to(device)
    energy = EnergyBased(model).to(device)

    print("Fitting...")
    pnml.fit(train_loader)


.. GENERATED FROM PYTHON SOURCE LINES 81-82

Stage 3: Evaluate the detectors on the combined CIFAR-10 and Textures test set

.. GENERATED FROM PYTHON SOURCE LINES 82-87

.. code-block:: Python
    :lineno-start: 82

    print("Testing...")
    evaluate("PNML", pnml, test_loader)
    evaluate("MSP", msp, test_loader)
    evaluate("Energy", energy, test_loader)


.. _sphx_glr_download_auto_examples_detectors_pnml.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: pnml.ipynb <pnml.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: pnml.py <pnml.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: pnml.zip <pnml.zip>`


.. only:: html

    .. rst-class:: sphx-glr-signature

    Gallery generated by Sphinx-Gallery
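
For reference, the two baseline detectors reduce each image's logits to a
single outlier score. The snippet below is a minimal plain-PyTorch sketch of
those two scores, not pytorch-ood's implementation. The helper names
``msp_score`` and ``energy_score``, the temperature ``t``, and the sign
convention (larger score means more likely out-of-distribution) are
assumptions made for illustration.

.. code-block:: Python

    import torch


    def msp_score(logits: torch.Tensor) -> torch.Tensor:
        # Negated maximum softmax probability (assumed convention:
        # confident predictions get low scores, uncertain ones get high scores).
        return -logits.softmax(dim=1).max(dim=1).values


    def energy_score(logits: torch.Tensor, t: float = 1.0) -> torch.Tensor:
        # Free energy E(x) = -t * logsumexp(f(x) / t), one value per row.
        return -t * torch.logsumexp(logits / t, dim=1)


    # Toy logits for a 3-class problem: one confident row, one uniform row.
    logits = torch.tensor([[10.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
    print(msp_score(logits))
    print(energy_score(logits))

With this convention, the uniform row receives the higher score under both
functions. PNML, by contrast, works from the penultimate features and the
final linear layer, which is why it receives ``model.features`` and
``model.fc`` and is fitted on training data.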