.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/osr/osr.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_osr_osr.py>`
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_osr_osr.py:


CIFAR10
-------------------------

Open set simulation on CIFAR10.

.. GENERATED FROM PYTHON SOURCE LINES 8-25

.. code-block:: Python
    :lineno-start: 8

    import torch.nn
    from torch.utils.data import DataLoader
    from torchvision.datasets import CIFAR10
    from torch.nn import CrossEntropyLoss
    from tqdm import tqdm
    from torchmetrics import Accuracy

    from pytorch_ood.dataset.ossim import DynamicOSS
    from pytorch_ood.model import WideResNet
    from pytorch_ood.detector import MaxSoftmax
    from pytorch_ood.utils import fix_random_seed, TargetMapping, OODMetrics, is_known

    # fall back to the CPU if no GPU is available
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    num_epochs = 10

    fix_random_seed(12345)

.. GENERATED FROM PYTHON SOURCE LINES 26-27

Set up preprocessing.

.. GENERATED FROM PYTHON SOURCE LINES 27-30

.. code-block:: Python
    :lineno-start: 27

    trans = WideResNet.transform_for("cifar10-pt")
    norm_std = WideResNet.norm_std_for("cifar10-pt")  # not used further in this example

.. GENERATED FROM PYTHON SOURCE LINES 31-32

Set up the datasets. The CIFAR10 training and test splits are combined into one dataset,
because the open set simulation creates its own train/test split.

.. GENERATED FROM PYTHON SOURCE LINES 32-36

.. code-block:: Python
    :lineno-start: 32

    dataset_1 = CIFAR10(root="data", train=True, transform=trans, download=True)
    dataset_2 = CIFAR10(root="data", train=False, transform=trans, download=True)
    # adding two PyTorch datasets with ``+`` yields a ConcatDataset
    dataset = dataset_1 + dataset_2

.. GENERATED FROM PYTHON SOURCE LINES 37-39

Create a WideResNet pre-trained on a downscaled version of ImageNet that excludes CIFAR images,
and adjust it to output 7 logits, one for each known class.

.. GENERATED FROM PYTHON SOURCE LINES 39-44

.. code-block:: Python
    :lineno-start: 39

    print("Creating a Model")
    model = WideResNet(num_classes=1000, pretrained="imagenet32-nocifar")
    # replace the 1000-way classification head with a 7-way head for the known classes
    model.fc = torch.nn.Linear(model.fc.in_features, 7)
    model.to(device)

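
Replacing ``model.fc`` keeps the weights of the feature extractor and only swaps the final linear
layer for a freshly initialized 7-way layer. The following sketch illustrates this mechanism on a
small network; ``TinyBackbone`` is hypothetical and not part of ``pytorch_ood``, and its random
weights merely stand in for pre-trained ones. The only assumption carried over from the example is
that the classification head is stored in an attribute named ``fc``.

.. code-block:: Python

    import torch
    from torch import nn


    class TinyBackbone(nn.Module):
        """Hypothetical stand-in for a pre-trained network with a 1000-way head.

        It only mimics the ``fc`` attribute that the example replaces on WideResNet.
        """

        def __init__(self, num_classes: int = 1000) -> None:
            super().__init__()
            self.features = nn.Sequential(
                nn.Flatten(), nn.Linear(3 * 32 * 32, 64), nn.ReLU()
            )
            self.fc = nn.Linear(64, num_classes)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc(self.features(x))


    torch.manual_seed(0)
    net = TinyBackbone(num_classes=1000)
    backbone_before = net.features[1].weight.detach().clone()

    # swap the head: same input width as the old head, but only 7 outputs
    net.fc = nn.Linear(net.fc.in_features, 7)

    logits = net(torch.randn(4, 3, 32, 32))
    print(logits.shape)  # torch.Size([4, 7])
    # the feature extractor still holds exactly the same weights as before the swap
    print(torch.equal(net.features[1].weight, backbone_before))  # True

Reusing ``in_features`` from the old head guarantees that the new layer accepts the feature vector
the backbone already produces, so no other part of the network has to change.
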

.. GENERATED FROM PYTHON SOURCE LINES 45-47

Create the open set simulation and the data loaders.
We use 3 classes as unknown unknowns (OOD) and split the data into 90% training and 10% test samples.

.. GENERATED FROM PYTHON SOURCE LINES 47-71

.. code-block:: Python
    :lineno-start: 47

    ossim = DynamicOSS(
        dataset=dataset,
        train_size=0.9,
        val_size=0.0,
        test_size=0.1,
        kuc=0,  # no known unknown classes
        uuc_val=0,
        uuc_test=3,  # unknown unknown classes that appear only in the test set
        seed=1,
    )

    print(f"Known Classes: {ossim.kkc}")
    print(f"Unknown Classes: {ossim.uuc}")

    # create class remapping
    class_mapping = TargetMapping(known=ossim.kkc, unknown=ossim.uuc)

    # shuffle the training data so that batches are not ordered by source dataset or class
    train_loader = DataLoader(
        ossim.train_dataset(), batch_size=32, shuffle=True, num_workers=12
    )
    test_loader = DataLoader(ossim.test_dataset(), batch_size=32, num_workers=12)

    criterion = CrossEntropyLoss()
    opti = torch.optim.Adam(model.parameters(), lr=0.001)

.. GENERATED FROM PYTHON SOURCE LINES 72-73

Define a function for testing. It evaluates OOD detection with the maximum softmax score and
measures the classification accuracy on the samples of known classes.

.. GENERATED FROM PYTHON SOURCE LINES 73-98

.. code-block:: Python
    :lineno-start: 73

    @torch.no_grad()
    def test():
        metrics = OODMetrics()
        acc = Accuracy(task="multiclass", num_classes=7).to(device)
        model.eval()

        for x, y in tqdm(test_loader):
            # do not forget to remap class labels
            y = torch.tensor([class_mapping(i.item()) for i in y])
            y = y.to(device)
            x = x.to(device)
            z = model(x)
            metrics.update(MaxSoftmax.score(z), y)

            # accuracy is only defined for samples of known classes
            known = is_known(y)
            if known.any():
                acc.update(z[known].argmax(dim=1), y[known])

        print(metrics.compute())
        print(acc.compute().item())

.. GENERATED FROM PYTHON SOURCE LINES 99-100

Start training.

.. GENERATED FROM PYTHON SOURCE LINES 100-126

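
Both ``test()`` and the training loop convert the raw CIFAR10 labels with ``class_mapping`` before
using them. The following sketch illustrates the idea with a plain dictionary; it is not
``TargetMapping``'s implementation. It assumes a hypothetical class split (the real one is chosen by
``DynamicOSS`` with ``seed=1`` and printed by the example) and, as the use of ``is_known`` suggests,
that known classes map to the consecutive labels 0 to 6 while unknown classes receive negative labels.
For simplicity, every unknown class is mapped to -1 here; the library may use other negative values.

.. code-block:: Python

    import torch

    # hypothetical class split; in the example, DynamicOSS picks it at random
    known_classes = [0, 1, 2, 4, 5, 7, 9]
    unknown_classes = [3, 6, 8]


    def build_mapping(known, unknown):
        """Map known classes to 0..len(known)-1 and every unknown class to -1.

        Assumption: this mirrors the contract the example relies on (non-negative
        labels for known classes, negative for unknown), not TargetMapping's code.
        """
        mapping = {c: i for i, c in enumerate(known)}
        mapping.update({c: -1 for c in unknown})
        return mapping


    mapping = build_mapping(known_classes, unknown_classes)

    # a batch of raw CIFAR10 labels, as produced by the DataLoader
    raw = torch.tensor([0, 3, 9, 8, 4])
    remapped = torch.tensor([mapping[i.item()] for i in raw])
    known_mask = remapped >= 0  # what is_known() is assumed to compute

    print(remapped.tolist())    # [0, -1, 6, -1, 3]
    print(known_mask.tolist())  # [True, False, True, False, True]

Because the known labels occupy 0 to 6 after remapping, the classifier needs exactly 7 outputs.
With ``kuc=0`` the unknown classes only occur in the test split, so ``CrossEntropyLoss`` never
receives a negative label during training.
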

.. code-block:: Python
    :lineno-start: 100

    for epoch in range(num_epochs):
        bar = tqdm(train_loader)
        model.train()
        loss_ema = None

        for x, y in bar:
            # do not forget to remap class labels
            y = torch.tensor([class_mapping(i.item()) for i in y])
            y = y.to(device)
            x = x.to(device)

            z = model(x)
            loss = criterion(z, y)
            opti.zero_grad()
            loss.backward()
            opti.step()

            # exponential moving average of the loss for the progress bar
            loss_ema = (
                loss_ema * 0.95 + loss.item() * 0.05 if loss_ema is not None else loss.item()
            )
            bar.set_postfix_str(f"loss: {loss_ema:.2f}")

        test()


.. _sphx_glr_download_auto_examples_osr_osr.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: osr.ipynb <osr.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: osr.py <osr.py>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_