.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/loss/supervised/outlier_exposure.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_loss_supervised_outlier_exposure.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_loss_supervised_outlier_exposure.py:


Outlier Exposure
-------------------------

We train a model with :class:`Outlier Exposure <pytorch_ood.loss.OutlierExposureLoss>` on CIFAR10.

As a starting point, we use a model pre-trained on a version of ImageNet
downscaled to :math:`32 \times 32`.
As outlier data, we use :class:`TinyImages300k <pytorch_ood.dataset.img.TinyImages300k>`,
a cleaned version of the TinyImages database, which contains random images scraped from the internet.

.. GENERATED FROM PYTHON SOURCE LINES 14-32

.. code-block:: Python
    :lineno-start: 14

    import torch
    import torchvision.transforms as tvt
    from torch.optim import Adam
    from torch.utils.data import DataLoader
    from torchvision.datasets import CIFAR10

    from pytorch_ood.dataset.img import Textures, TinyImages300k
    from pytorch_ood.detector import MaxSoftmax
    from pytorch_ood.loss import OutlierExposureLoss
    from pytorch_ood.model import WideResNet
    from pytorch_ood.utils import OODMetrics, ToUnknown

    torch.manual_seed(123)

    # number of training epochs
    n_epochs = 10
    device = "cuda:0"


.. GENERATED FROM PYTHON SOURCE LINES 33-34

Set up preprocessing and data

.. GENERATED FROM PYTHON SOURCE LINES 34-59

.. code-block:: Python
    :lineno-start: 34

    trans = tvt.Compose([tvt.Resize(size=(32, 32)), tvt.ToTensor()])

    # setup IN training data
    dataset_in_train = CIFAR10(root="data", train=True, download=True, transform=trans)

    # setup OOD training data, use ToUnknown() to mark labels as OOD
    # this way, outlier exposure can automatically decide if the training samples are IN or OOD
    dataset_out_train = TinyImages300k(
        root="data", download=True, transform=trans, target_transform=ToUnknown()
    )

    # setup IN test data
    dataset_in_test = CIFAR10(root="data", train=False, transform=trans)

    # setup OOD test data, use ToUnknown() to mark labels as OOD
    dataset_out_test = Textures(
        root="data", download=True, transform=trans, target_transform=ToUnknown()
    )

    # create data loaders
    train_loader = DataLoader(
        dataset_in_train + dataset_out_train, batch_size=64, shuffle=True
    )
    test_loader = DataLoader(dataset_in_test + dataset_out_test, batch_size=64)


.. GENERATED FROM PYTHON SOURCE LINES 60-61

Create a DNN pretrained on ImageNet, excluding the CIFAR10 classes

.. GENERATED FROM PYTHON SOURCE LINES 61-72

.. code-block:: Python
    :lineno-start: 61

    model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")

    # we have to replace the final layer to account for the lower number of
    # classes in the CIFAR10 dataset
    model.fc = torch.nn.Linear(model.fc.in_features, 10)

    model.to(device)

    opti = Adam(model.parameters())
    criterion = OutlierExposureLoss(alpha=0.5)


.. GENERATED FROM PYTHON SOURCE LINES 73-74

Define a function to test the model

.. GENERATED FROM PYTHON SOURCE LINES 74-88

.. code-block:: Python
    :lineno-start: 74

    def test():
        softmax = MaxSoftmax(model)

        metrics_softmax = OODMetrics()
        model.eval()

        with torch.no_grad():
            for x, y in test_loader:
                metrics_softmax.update(softmax(x.to(device)), y)

        print(metrics_softmax.compute())
        model.train()


.. GENERATED FROM PYTHON SOURCE LINES 89-90

Start training

.. GENERATED FROM PYTHON SOURCE LINES 90-100

.. code-block:: Python
    :lineno-start: 90

    for epoch in range(n_epochs):
        print(f"Epoch {epoch}")
        for x, y in train_loader:
            logits = model(x.to(device))
            loss = criterion(logits, y.to(device))
            opti.zero_grad()
            loss.backward()
            opti.step()

        test()


.. _sphx_glr_download_auto_examples_loss_supervised_outlier_exposure.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: outlier_exposure.ipynb <outlier_exposure.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: outlier_exposure.py <outlier_exposure.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
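The training loop in this example relies on ``OutlierExposureLoss`` to treat IN and OOD
samples differently within one mixed batch. The sketch below illustrates one common
formulation of that idea: cross-entropy on IN samples, plus ``alpha`` times the cross-entropy
between the predicted distribution and a uniform distribution on OOD samples. It is a
framework-free illustration, not the library's implementation. The function names
``log_softmax`` and ``outlier_exposure_loss_sketch`` are hypothetical, and the convention
that a negative label marks an OOD sample is an assumption about what ``ToUnknown()`` produces.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax of a single logit vector."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def outlier_exposure_loss_sketch(batch_logits, labels, alpha=0.5):
    """Illustrative objective: CE on IN samples + alpha * uniform-CE on OOD samples.

    Assumption: labels < 0 mark OOD samples, labels >= 0 are class indices.
    """
    in_losses, out_losses = [], []
    for logits, y in zip(batch_logits, labels):
        log_p = log_softmax(logits)
        if y >= 0:
            # standard cross-entropy for an IN sample
            in_losses.append(-log_p[y])
        else:
            # cross-entropy against the uniform distribution over classes
            out_losses.append(-sum(log_p) / len(log_p))

    loss_in = sum(in_losses) / len(in_losses) if in_losses else 0.0
    loss_out = sum(out_losses) / len(out_losses) if out_losses else 0.0
    return loss_in + alpha * loss_out


# one IN sample (class 1) and one OOD sample, both with uniform predictions
mixed_loss = outlier_exposure_loss_sketch([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [1, -1])
```

Under this formulation, the OOD term is smallest when the model predicts a uniform
distribution, so confident predictions on outliers are penalized. This is what makes the
maximum softmax score used in ``test()`` a useful OOD score after training.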