.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/segmentation/street.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_segmentation_street.py:


StreetHazards
-------------------------

On the :class:`StreetHazards` dataset, we train a Feature Pyramid Network (FPN) segmentation model whose ResNet-50 backbone is pre-trained on ImageNet.
We then use the :class:`EnergyBased` OOD detector to compute per-pixel outlier scores on the test set.

.. note ::
    Training with a batch size of 4 requires slightly more than 12 GB of GPU memory.
    However, the model also tends to converge to reasonable performance with a smaller batch size.

.. warning ::
    The results produced by this script vary between runs. Exact numerical values cannot be reproduced,
    because the model includes operations for which no deterministic implementation existed at the time of writing.

.. note ::
    The license of the model originally used for the StreetHazards dataset is not compatible with ``pytorch-ood``.
    This prevents us from reusing the implementation from the original repository.

.. GENERATED FROM PYTHON SOURCE LINES 22-42

.. code-block:: Python
    :lineno-start: 22

    import segmentation_models_pytorch as smp
    import torch
    from segmentation_models_pytorch.encoders import get_preprocessing_fn
    from segmentation_models_pytorch.metrics import iou_score
    from torch.utils.data import DataLoader
    from torchvision.transforms.functional import pad, to_tensor

    from pytorch_ood.dataset.img import StreetHazards
    from pytorch_ood.detector import EnergyBased
    from pytorch_ood.utils import OODMetrics, fix_random_seed

    device = "cuda:0"
    batch_size = 4
    num_epochs = 1

    fix_random_seed(12345)
    g = torch.Generator()
    g.manual_seed(0)

.. GENERATED FROM PYTHON SOURCE LINES 43-44

Set up the preprocessing.

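
The ResNet-50 encoder downsamples its input by a factor of 32, so the transform that follows pads each image until its height and width are multiples of 32. The sketch below computes such a symmetric padding for an arbitrary size. It assumes StreetHazards frames of 1280 x 720 pixels, and ``pad_to_multiple`` is a hypothetical helper, not part of ``pytorch-ood`` or ``segmentation_models_pytorch``. The result uses the order of torchvision's two-element ``pad`` argument, ``[left/right, top/bottom]``.

.. code-block:: Python

    def pad_to_multiple(height, width, multiple=32):
        """Hypothetical helper: symmetric padding that makes both sides multiples of ``multiple``.

        Returns ``[left/right, top/bottom]``, the order used by torchvision's
        ``pad(img, [a, b])``. Assumes the missing pixels split evenly on both sides.
        """

        def per_side(size):
            deficit = (-size) % multiple  # pixels missing up to the next multiple
            if deficit % 2:
                raise ValueError(f"cannot pad {size} evenly to a multiple of {multiple}")
            return deficit // 2

        return [per_side(width), per_side(height)]


    # Assumed StreetHazards frame size: 720 rows x 1280 columns.
    print(pad_to_multiple(720, 1280))  # [0, 8] -> 736 x 1280, and 736 = 23 * 32
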

.. GENERATED FROM PYTHON SOURCE LINES 44-59

.. code-block:: Python
    :lineno-start: 44

    preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")


    def my_transform(img, target):
        img = to_tensor(img)[:3, :, :]  # keep the first three channels, drop the 4th
        img = torch.moveaxis(img, 0, -1)  # CHW -> HWC: preprocess_input normalizes along the last axis
        img = preprocess_input(img)
        img = torch.moveaxis(img, -1, 0)  # HWC -> CHW

        # the model needs height and width divisible by 32: pad 8 rows at the top and bottom
        img = pad(img, [0, 8]).float()
        target = pad(target, [0, 8])
        return img, target

.. GENERATED FROM PYTHON SOURCE LINES 60-61

Set up the datasets.

.. GENERATED FROM PYTHON SOURCE LINES 61-69

.. code-block:: Python
    :lineno-start: 61

    dataset = StreetHazards(
        root="data", subset="train", transform=my_transform, download=True
    )
    dataset_test = StreetHazards(
        root="data", subset="test", transform=my_transform, download=True
    )

.. GENERATED FROM PYTHON SOURCE LINES 70-71

Set up the model.

.. GENERATED FROM PYTHON SOURCE LINES 71-78

.. code-block:: Python
    :lineno-start: 71

    model = smp.FPN(
        encoder_name="resnet50",
        encoder_weights="imagenet",
        in_channels=3,
        classes=13,
    ).to(device)

.. GENERATED FROM PYTHON SOURCE LINES 79-80

Train the model for ``num_epochs`` epochs.

.. GENERATED FROM PYTHON SOURCE LINES 80-121

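
The training loop that follows tracks segmentation quality with ``smp.metrics.get_stats`` and ``iou_score``. For each class, the intersection over union is IoU = TP / (TP + FP + FN), computed from pixel-wise true positives, false positives and false negatives. The sketch below computes it directly in PyTorch for illustration, assuming integer label maps of equal shape. ``per_class_iou`` is a hypothetical helper. Two differences from the loop are deliberate. First, the counts are pooled over all pixels instead of per image. Second, classes absent from both maps are returned as NaN, whereas ``iou_score`` resolves such 0/0 cases through its ``zero_division`` argument. Averages computed both ways can therefore differ.

.. code-block:: Python

    import torch


    def per_class_iou(pred, target, num_classes):
        """Hypothetical helper: IoU = TP / (TP + FP + FN) for each class.

        ``pred`` and ``target`` are integer label maps of the same shape; counts are
        pooled over all pixels. Classes absent from both maps yield NaN.
        """
        ious = torch.empty(num_classes, dtype=torch.float64)
        for c in range(num_classes):
            p, t = pred == c, target == c
            tp = (p & t).sum().item()
            fp = (p & ~t).sum().item()
            fn = (~p & t).sum().item()
            denom = tp + fp + fn
            ious[c] = tp / denom if denom > 0 else float("nan")
        return ious


    # Toy 2x3 label maps with 3 classes (one image, no batch dimension).
    target = torch.tensor([[0, 0, 1], [1, 2, 2]])
    pred = torch.tensor([[0, 1, 1], [1, 2, 0]])
    iou = per_class_iou(pred, target, num_classes=3)
    print(iou)  # per-class IoU: 1/3, 2/3 and 1/2
    print(iou.nanmean().item())  # mean over classes present in either map
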

.. code-block:: Python
    :lineno-start: 80

    criterion = smp.losses.DiceLoss(mode="multiclass")
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=10,
        worker_init_fn=fix_random_seed,
        generator=g,
    )

    # exponential moving averages of the loss and the IoU, used for logging
    loss_ema = 0
    iou_ema = 0

    for epoch in range(num_epochs):
        for n, (x, y) in enumerate(loader):
            optimizer.zero_grad()
            y, x = y.to(device), x.to(device)

            y_hat = model(x)
            loss = criterion(y_hat, y)
            loss.backward()
            optimizer.step()

            tp, fp, fn, tn = smp.metrics.get_stats(
                y_hat.softmax(dim=1).max(dim=1).indices.long(),  # predicted label map
                y.long(),
                mode="multiclass",
                num_classes=13,
            )
            iou = iou_score(tp, fp, fn, tn)

            loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
            iou_ema = 0.8 * iou_ema + 0.2 * iou.mean().item()

            if n % 10 == 0:
                print(
                    f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {iou_ema:02.2f}"
                )

.. GENERATED FROM PYTHON SOURCE LINES 122-123

Evaluate the model and the OOD detector on the test set.

.. GENERATED FROM PYTHON SOURCE LINES 123-144

.. code-block:: Python
    :lineno-start: 123

    print("Evaluating")
    model.eval()

    loader = DataLoader(
        dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g
    )
    detector = EnergyBased(model)
    metrics = OODMetrics(mode="segmentation")

    with torch.no_grad():
        for n, (x, y) in enumerate(loader):
            y, x = y.to(device), x.to(device)
            o = detector(x)  # per-pixel outlier scores

            # undo padding: crop the 8 rows added at the top and bottom
            o = pad(o, [0, -8])
            y = pad(y, [0, -8])

            metrics.update(o, y)

    print(metrics.compute())

.. GENERATED FROM PYTHON SOURCE LINES 145-147

Output of one run (the exact values vary between runs):

.. code-block:: none

    {'AUROC': 0.8069181442260742, 'AUPR-IN': 0.07396415621042252, 'AUPR-OUT': 0.9966945648193359, 'FPR95TPR': 0.7595465183258057}

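
For reference, the per-pixel scores evaluated above come from :class:`EnergyBased`. A minimal sketch of the underlying energy score follows. It assumes the standard formulation E = -T * logsumexp(f(x) / T), applied over the class dimension of the segmentation logits. Higher energy suggests that a pixel is more likely out-of-distribution. ``pixel_energy`` is a hypothetical helper that illustrates the technique. It is not the library's implementation.

.. code-block:: Python

    import torch


    def pixel_energy(logits, temperature=1.0):
        """Hypothetical helper: per-pixel energy score over the class dimension.

        Computes E = -T * logsumexp(logits / T, dim=1) for logits of shape
        (B, C, H, W) and returns a (B, H, W) map; higher energy suggests a pixel
        is more likely out-of-distribution.
        """
        return -temperature * torch.logsumexp(logits / temperature, dim=1)


    # Toy logits for one image with 3 classes and 1 x 2 pixels:
    # the first pixel is confidently class 0, the second is maximally uncertain.
    logits = torch.zeros(1, 3, 1, 2)
    logits[0, 0, 0, 0] = 10.0

    scores = pixel_energy(logits)
    print(scores.shape)  # torch.Size([1, 1, 2])
    print(scores[0, 0, 0] < scores[0, 0, 1])  # tensor(True): the confident pixel has lower energy

Under these assumptions, ``detector(x)`` in the evaluation loop corresponds to ``pixel_energy(model(x))``, which yields one score per pixel. These scores are then cropped and passed to ``metrics.update``.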