.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/segmentation/street_entropic.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_segmentation_street_entropic.py:


StreetHazards with Entropic Loss
-------------------------------------

We train a Feature Pyramid Network (FPN) segmentation model with a ResNet-50 backbone (pre-trained on ImageNet) on the :class:`StreetHazards` **test set**, using the supervised :class:`EntropicOpenSetLoss`. We then use the :class:`Entropy` OOD detector to score each pixel.

This setup only serves to demonstrate how to train a supervised anomaly segmentation model with this loss function.

.. warning:: We train on the test set because it contains examples of anomalies.
    The reported results are therefore not meaningful.

.. note:: Training with a batch size of 4 requires slightly more than 12 GB of GPU memory.
    However, the model also tends to converge to reasonable performance with a smaller batch size.

.. GENERATED FROM PYTHON SOURCE LINES 23-45

.. code-block:: Python
    :lineno-start: 24

    import segmentation_models_pytorch as smp
    import torch
    from segmentation_models_pytorch.encoders import get_preprocessing_fn
    from segmentation_models_pytorch.metrics import iou_score
    from torch.utils.data import DataLoader
    from torchvision.transforms.functional import pad, to_tensor

    from pytorch_ood.dataset.img import StreetHazards
    from pytorch_ood.detector import Entropy
    from pytorch_ood.loss import EntropicOpenSetLoss
    from pytorch_ood.utils import OODMetrics, fix_random_seed

    device = "cuda:0"  # see the note above on GPU memory requirements
    batch_size = 4
    num_epochs = 1

    fix_random_seed(12345)
    # separate generator so that the DataLoader shuffling is reproducible
    g = torch.Generator()
    g.manual_seed(0)

.. GENERATED FROM PYTHON SOURCE LINES 46-47

Set up preprocessing

.. GENERATED FROM PYTHON SOURCE LINES 47-62

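
The preprocessing pads each image vertically, because the model expects height and width to be divisible by 32. As a rough illustration, the following plain-Python sketch computes the symmetric padding that brings a side length up to the next multiple of 32. The helper ``symmetric_pad_to_multiple`` is hypothetical (it is not part of pytorch-ood or segmentation-models-pytorch), and the 1280 x 720 frame size is an assumption made only for this example.

.. code-block:: Python

    def symmetric_pad_to_multiple(size: int, multiple: int = 32) -> tuple:
        """Hypothetical helper: return (before, after) padding so that
        size + before + after is divisible by ``multiple``."""
        padded = -(-size // multiple) * multiple  # round up to the next multiple
        total = padded - size
        before = total // 2
        after = total - before  # any odd remainder goes to the end
        return before, after


    # Assumed frame size for illustration: height 720, width 1280.
    print(symmetric_pad_to_multiple(720))   # (8, 8), since 720 + 16 = 736 = 23 * 32
    print(symmetric_pad_to_multiple(1280))  # (0, 0), since 1280 = 40 * 32

Under this assumption, the result matches ``pad(img, [0, 8])`` in the transform, which adds 8 pixels at the top and bottom and none on the left and right.
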

.. code-block:: Python
    :lineno-start: 47

    preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")


    def my_transform(img, target):
        img = to_tensor(img)[:3, :, :]  # drop the 4th channel
        # preprocess_input normalizes along the last axis, so move channels last
        img = torch.moveaxis(img, 0, -1)
        img = preprocess_input(img)
        img = torch.moveaxis(img, -1, 0)
        # size must be divisible by 32, so we pad 8 pixels at the top and bottom
        # (and none on the left and right)
        img = pad(img, [0, 8]).float()
        target = pad(target, [0, 8])
        return img, target

.. GENERATED FROM PYTHON SOURCE LINES 63-64

Set up the datasets. For demonstration purposes, we train on the test set, which contains OOD images.

.. GENERATED FROM PYTHON SOURCE LINES 64-68

.. code-block:: Python
    :lineno-start: 64

    dataset = StreetHazards(root="data", subset="test", transform=my_transform, download=True)
    dataset_test = StreetHazards(root="data", subset="test", transform=my_transform, download=True)

.. GENERATED FROM PYTHON SOURCE LINES 69-70

Set up the model

.. GENERATED FROM PYTHON SOURCE LINES 70-77

.. code-block:: Python
    :lineno-start: 70

    model = smp.FPN(
        encoder_name="resnet50",
        encoder_weights="imagenet",
        in_channels=3,
        classes=13,
    ).to(device)

.. GENERATED FROM PYTHON SOURCE LINES 78-79

Train the model for a few epochs

.. GENERATED FROM PYTHON SOURCE LINES 79-121

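
As commonly formulated, the entropic open-set loss treats known and unknown pixels differently: known pixels receive the usual cross-entropy, while unknown pixels are pushed towards a uniform softmax, that is, towards maximal entropy. The following plain-Python sketch illustrates this idea for a handful of pixels. It is not the pytorch-ood implementation of :class:`EntropicOpenSetLoss`; the function names are hypothetical, and the sketch assumes, in line with how pytorch-ood marks OOD data, that anomalous pixels carry negative labels.

.. code-block:: Python

    import math


    def log_softmax(logits):
        """Numerically stable log-softmax over a list of class scores."""
        m = max(logits)
        log_z = m + math.log(sum(math.exp(v - m) for v in logits))
        return [v - log_z for v in logits]


    def entropic_open_set_loss_sketch(pixel_logits, pixel_labels):
        """Hypothetical per-pixel entropic open-set loss, averaged over pixels.

        pixel_logits: list of per-pixel lists of class scores.
        pixel_labels: list of ints; a negative label marks an unknown pixel (assumption).
        """
        losses = []
        for logits, label in zip(pixel_logits, pixel_labels):
            log_p = log_softmax(logits)
            if label >= 0:
                # known pixel: ordinary cross-entropy on the true class
                losses.append(-log_p[label])
            else:
                # unknown pixel: mean negative log-probability over all classes,
                # which is smallest when the softmax is uniform
                losses.append(-sum(log_p) / len(log_p))
        return sum(losses) / len(losses)


    # Hypothetical 13-class logits: one known pixel (class 2), one unknown pixel.
    logits = [[0.0] * 13, [0.0] * 13]
    logits[0][2] = 6.0
    labels = [2, -1]
    print(entropic_open_set_loss_sketch(logits, labels))

For an unknown pixel, the term is bounded below by log(13) with 13 classes, and this minimum is reached only when all classes are equally likely, which is why the loss drives anomalous pixels towards high-entropy predictions.
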

.. code-block:: Python
    :lineno-start: 80

    criterion = EntropicOpenSetLoss()
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=10,
        worker_init_fn=fix_random_seed,
        generator=g,
    )

    # exponential moving averages, only used for logging
    loss_ema = 0
    iou_ema = 0

    for epoch in range(num_epochs):
        for n, (x, y) in enumerate(loader):
            optimizer.zero_grad()
            y, x = y.to(device), x.to(device)
            y_hat = model(x)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()

            # IoU of the predicted class map (argmax over the class dimension)
            tp, fp, fn, tn = smp.metrics.get_stats(
                y_hat.argmax(dim=1),
                y.long(),
                mode="multiclass",
                num_classes=13,
            )
            iou = iou_score(tp, fp, fn, tn)

            loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
            iou_ema = 0.8 * iou_ema + 0.2 * iou.mean().item()

            if n % 10 == 0:
                print(
                    f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {iou_ema:02.2f}"
                )

.. GENERATED FROM PYTHON SOURCE LINES 122-123

Evaluate

.. GENERATED FROM PYTHON SOURCE LINES 123-143

.. code-block:: Python
    :lineno-start: 123

    print("Evaluating")
    model.eval()
    loader = DataLoader(dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g)
    detector = Entropy(model)
    metrics = OODMetrics(mode="segmentation")

    with torch.no_grad():
        for n, (x, y) in enumerate(loader):
            y, x = y.to(device), x.to(device)
            o = detector(x)  # per-pixel outlier scores

            # undo padding: a negative pad crops 8 pixels from the top and bottom
            o = pad(o, [0, -8])
            y = pad(y, [0, -8])

            metrics.update(o, y)

    print(metrics.compute())

.. GENERATED FROM PYTHON SOURCE LINES 144-151

Output:

+-----------------------------+----------+-------+-------+---------+----------+----------+
| Dataset                     | Detector | AUROC | AUTC  | AUPR-IN | AUPR-OUT | FPR95TPR |
+=============================+==========+=======+=======+=========+==========+==========+
| StreetHazards+Entropic-Loss | Entropy  | 97.23 | 18.69 | 99.98   | 39.51    | 10.43    |
+-----------------------------+----------+-------+-------+---------+----------+----------+


.. _sphx_glr_download_auto_examples_segmentation_street_entropic.py:

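
The :class:`Entropy` detector used in the evaluation scores each pixel by the entropy of the model's softmax output, so that uncertain pixels receive higher outlier scores. As an illustration only, the following plain-Python sketch computes this score for two hypothetical 13-class pixels; ``entropy_score`` is a hypothetical helper, not the pytorch-ood implementation.

.. code-block:: Python

    import math


    def softmax(logits):
        """Numerically stable softmax over a list of class scores."""
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        return [e / total for e in exps]


    def entropy_score(logits):
        """Hypothetical helper: Shannon entropy (in nats) of the softmax.
        Higher values mean a less confident, more OOD-like pixel."""
        return -sum(p * math.log(p) for p in softmax(logits) if p > 0)


    confident_pixel = [12.0] + [0.0] * 12  # one class clearly dominates
    uncertain_pixel = [0.0] * 13  # all 13 classes equally likely

    print(entropy_score(confident_pixel))  # close to 0
    print(round(entropy_score(uncertain_pixel), 4))  # 2.5649, i.e. log(13)

With 13 classes the score is at most log(13), reached by the uniform distribution, which is the kind of prediction the entropic loss encourages for anomalous pixels during training.
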

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: street_entropic.ipynb `

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: street_entropic.py `

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: street_entropic.zip `


.. only:: html

    .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_