.. DO NOT EDIT. .. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY. .. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE: .. "auto_examples/benchmarks/manual/cifar100_baseline.py" .. LINE NUMBERS ARE GIVEN BELOW. .. only:: html .. note:: :class: sphx-glr-download-link-note :ref:`Go to the end ` to download the full example code .. rst-class:: sphx-glr-example-title .. _sphx_glr_auto_examples_benchmarks_manual_cifar100_baseline.py: CIFAR 100 ============================== The evaluation is the same as for CIFAR 10. +------------------+-------+---------+----------+----------+ | Detector | AUROC | AUPR-IN | AUPR-OUT | FPR95TPR | +==================+=======+=========+==========+==========+ | SHE | 59.43 | 77.44 | 68.37 | 100.00 | +------------------+-------+---------+----------+----------+ | Mahalanobis | 75.35 | 81.59 | 65.62 | 58.87 | +------------------+-------+---------+----------+----------+ | MSP | 78.78 | 82.37 | 71.34 | 57.67 | +------------------+-------+---------+----------+----------+ | Mahalanobis+ODIN | 79.24 | 84.58 | 68.69 | 55.91 | +------------------+-------+---------+----------+----------+ | KLMatching | 79.88 | 83.53 | 68.23 | 60.02 | +------------------+-------+---------+----------+----------+ | ODIN | 80.80 | 83.96 | 73.40 | 54.92 | +------------------+-------+---------+----------+----------+ | Entropy | 81.19 | 84.61 | 73.08 | 56.49 | +------------------+-------+---------+----------+----------+ | ViM | 81.73 | 85.87 | 72.91 | 49.86 | +------------------+-------+---------+----------+----------+ | RMD | 83.23 | 86.94 | 74.56 | 50.55 | +------------------+-------+---------+----------+----------+ | MaxLogit | 84.70 | 86.66 | 78.33 | 47.40 | +------------------+-------+---------+----------+----------+ | EnergyBased | 85.00 | 86.88 | 78.69 | 46.70 | +------------------+-------+---------+----------+----------+ | DICE | 85.35 | 87.32 | 78.99 | 46.17 | +------------------+-------+---------+----------+----------+ .. GENERATED FROM PYTHON SOURCE LINES 37-75 .. code-block:: Python :lineno-start: 37 import pandas as pd # additional dependency, used here for convenience import torch from torch.utils.data import DataLoader from torchvision.datasets import CIFAR100, CIFAR10, MNIST, FashionMNIST from pytorch_ood.dataset.img import ( LSUNCrop, LSUNResize, Textures, TinyImageNetCrop, TinyImageNetResize, Places365, TinyImageNet, ) from pytorch_ood.detector import ( ODIN, EnergyBased, Entropy, KLMatching, Mahalanobis, MaxLogit, MaxSoftmax, ViM, RMD, DICE, SHE, ) from pytorch_ood.model import WideResNet from pytorch_ood.utils import OODMetrics, ToUnknown, fix_random_seed device = "cuda:0" fix_random_seed(123) # setup preprocessing trans = WideResNet.transform_for("cifar100-pt") norm_std = WideResNet.norm_std_for("cifar100-pt") .. GENERATED FROM PYTHON SOURCE LINES 76-77 Setup datasets .. GENERATED FROM PYTHON SOURCE LINES 77-101 .. code-block:: Python :lineno-start: 77 dataset_in_test = CIFAR100(root="data", train=False, transform=trans, download=True) # create all OOD datasets ood_datasets = [ Textures, TinyImageNetCrop, TinyImageNetResize, LSUNCrop, LSUNResize, Places365, CIFAR10, MNIST, FashionMNIST, ] datasets = {} for ood_dataset in ood_datasets: dataset_out_test = ood_dataset( root="data", transform=trans, target_transform=ToUnknown(), download=True ) test_loader = DataLoader( dataset_in_test + dataset_out_test, batch_size=256, num_workers=12 ) datasets[ood_dataset.__name__] = test_loader .. GENERATED FROM PYTHON SOURCE LINES 102-103 **Stage 1**: Create DNN with pre-trained weights from the Hendrycks baseline paper .. GENERATED FROM PYTHON SOURCE LINES 103-126 .. code-block:: Python :lineno-start: 103 print("STAGE 1: Creating a Model") model = WideResNet(num_classes=100, pretrained="cifar100-pt").eval().to(device) # Stage 2: Create OOD detector print("STAGE 2: Creating OOD Detectors") detectors = {} detectors["Entropy"] = Entropy(model) detectors["ViM"] = ViM(model.features, d=64, w=model.fc.weight, b=model.fc.bias) detectors["Mahalanobis+ODIN"] = Mahalanobis( model.features, norm_std=norm_std, eps=0.002 ) detectors["Mahalanobis"] = Mahalanobis(model.features) detectors["KLMatching"] = KLMatching(model) detectors["SHE"] = SHE(model.features, model.fc) detectors["MSP"] = MaxSoftmax(model) detectors["EnergyBased"] = EnergyBased(model) detectors["MaxLogit"] = MaxLogit(model) detectors["ODIN"] = ODIN(model, norm_std=norm_std, eps=0.002) detectors["DICE"] = DICE( model=model.features, w=model.fc.weight, b=model.fc.bias, p=0.65 ) detectors["RMD"] = RMD(model.features) .. GENERATED FROM PYTHON SOURCE LINES 127-128 **Stage 2**: fit detectors to training data (some require this, some do not) .. GENERATED FROM PYTHON SOURCE LINES 128-136 .. code-block:: Python :lineno-start: 128 print(f"> Fitting {len(detectors)} detectors") loader_in_train = DataLoader( CIFAR100(root="data", train=True, transform=trans), batch_size=256, num_workers=12 ) for name, detector in detectors.items(): print(f"--> Fitting {name}") detector.fit(loader_in_train, device=device) .. GENERATED FROM PYTHON SOURCE LINES 137-138 **Stage 3**: Evaluate Detectors .. GENERATED FROM PYTHON SOURCE LINES 138-158 .. code-block:: Python :lineno-start: 138 print(f"STAGE 3: Evaluating {len(detectors)} detectors on {len(datasets)} datasets.") results = [] with torch.no_grad(): for detector_name, detector in detectors.items(): print(f"> Evaluating {detector_name}") for dataset_name, loader in datasets.items(): print(f"--> {dataset_name}") metrics = OODMetrics() for x, y in loader: metrics.update(detector(x.to(device)), y.to(device)) r = {"Detector": detector_name, "Dataset": dataset_name} r.update(metrics.compute()) results.append(r) # calculate mean scores over all datasets, use percent df = pd.DataFrame(results) mean_scores = df.groupby("Detector").mean() * 100 print(mean_scores.sort_values("AUROC").to_csv(float_format="%.2f")) .. _sphx_glr_download_auto_examples_benchmarks_manual_cifar100_baseline.py: .. only:: html .. container:: sphx-glr-footer sphx-glr-footer-example .. container:: sphx-glr-download sphx-glr-download-jupyter :download:`Download Jupyter notebook: cifar100_baseline.ipynb ` .. container:: sphx-glr-download sphx-glr-download-python :download:`Download Python source code: cifar100_baseline.py ` .. only:: html .. rst-class:: sphx-glr-signature `Gallery generated by Sphinx-Gallery `_