RoadAnomaly

We train a Feature Pyramid Network (FPN) segmentation model with a ResNet-50 backbone, pre-trained on ImageNet, on the Cityscapes dataset (please download it beforehand and put the gtFine and leftImg8bit folders into the data/cityscapes folder). This model is evaluated using the EnergyBased OOD detector on the original RoadAnomaly dataset and on both datasets of the SegmentMeIfYouCan benchmark: RoadAnomaly21 and RoadObstacle21.

Note

Training with a batch-size of 4 requires slightly more than 12 GB of GPU memory. However, the models tend to also converge to reasonable performance with a smaller batch-size.

Warning

The results produced by this script vary. It is impossible to ensure the reproducibility of the exact numerical values at the moment, because the model includes operations for which no deterministic implementation exists at the time of writing.

20 import segmentation_models_pytorch as smp
21 import torch
22 from segmentation_models_pytorch.encoders import get_preprocessing_fn
23 from segmentation_models_pytorch.metrics import iou_score
24 from torch.utils.data import DataLoader
25 from torchvision.transforms.functional import pad, to_tensor
26 from torchvision.datasets import Cityscapes
27 from PIL import Image
28
29 from pytorch_ood.dataset.img import RoadAnomaly, SegmentMeIfYouCan
30 from pytorch_ood.detector import EnergyBased
31 from pytorch_ood.utils import OODMetrics, fix_random_seed
32
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# that the script still starts on machines without a GPU (training will be
# considerably slower there).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4  # batch size of the training data loader
num_epochs = 1  # number of passes over the training set
classes = 34  # number of classes the segmentation model predicts
# Fix the random seed through the pytorch_ood helper and create a separately
# seeded generator that is handed to the data loaders.
fix_random_seed(12345)
g = torch.Generator()
g.manual_seed(0)

Setup preprocessing

# Preprocessing function that belongs to the ImageNet weights of the
# ResNet-50 encoder.
preprocess_input = get_preprocessing_fn(encoder_name="resnet50", pretrained="imagenet")
45
46
def my_transform(img, target):
    """Turn an image/target pair into padded tensors for the model.

    The image is converted to a tensor, reduced to its first three channels,
    passed through ``preprocess_input`` and, if necessary, resized together
    with the target to 720x1280 (height x width). Both are then padded by
    8 rows at the top and at the bottom.

    :param img: input image, anything accepted by ``to_tensor``
    :param target: label map as a 2D tensor (it is indexed and padded as a
        tensor below)
    :return: tuple ``(img, target)`` of the float image tensor and the padded
        target
    """
    rgb = to_tensor(img)[:3, :, :]  # keep only the first three channels

    # preprocess_input is applied with the channel axis last; move it back
    # to the front afterwards.
    channels_last = torch.moveaxis(rgb, 0, -1)
    normalized = preprocess_input(channels_last)
    img = torch.moveaxis(normalized, -1, 0)

    # bring everything that is not already 1280x720 to that resolution
    height, width = img.shape[-2:]
    needs_resize = not (height == 720 and width == 1280)
    if needs_resize:
        resized_img = torch.nn.functional.interpolate(
            img.unsqueeze(0), size=(720, 1280), mode="bilinear", align_corners=False
        )
        img = resized_img[0]
        # nearest-neighbour interpolation so that no new label values appear
        resized_target = torch.nn.functional.interpolate(
            target.unsqueeze(0).unsqueeze(0).float(), size=(720, 1280), mode="nearest"
        )
        target = resized_target[0, 0].long()

    # the height has to be divisible by 32: 720 + 2 * 8 = 736
    padded_img = pad(img, [0, 8]).float()
    padded_target = pad(target, [0, 8])
    return padded_img, padded_target
67
68
def cityscapes_transform(img, target):
    """Prepare a Cityscapes sample (PIL image and PIL label map) for training.

    Image and label map are resized to 1280x720, the label map is converted
    to a tensor of integer class ids, and both are handed to ``my_transform``.

    :param img: PIL image
    :param target: PIL image holding the semantic label ids
    :return: the ``(img, target)`` tuple returned by ``my_transform``
    """
    # resize image and target to 1280,720 (PIL sizes are width, height)
    img = img.resize((1280, 720))
    # use nearest neighbour interpolation for target so label ids are not mixed
    target = target.resize((1280, 720), Image.NEAREST)
    target = to_tensor(target).squeeze(0)
    # to_tensor scales 8-bit values to [0, 1]; scale back and round before the
    # integer cast, because truncating a float such as 6.9999995 would silently
    # shift the label id down by one.
    target = (target * 255).round().long()
    return my_transform(img, target)
77
78
# NOTE: the name shadows the builtin ``eval``; it is kept because the rest of
# the script calls it under this name.
def eval(dataset_test, detector):
    """Run ``detector`` over ``dataset_test`` and print the OOD metrics.

    :param dataset_test: dataset yielding ``(image, target)`` pairs that were
        padded by 8 rows at the top and bottom (as done by ``my_transform``)
    :param detector: callable that maps a batch of images to outlier scores
    :return: None, the computed metrics are printed
    """
    metrics = OODMetrics(mode="segmentation", void_label=1)
    # use the module-level batch size so evaluation follows the same memory
    # budget as training instead of a separately hard-coded value
    loader = DataLoader(
        dataset_test,
        batch_size=batch_size,
        worker_init_fn=fix_random_seed,
        generator=g,
    )

    with torch.no_grad():
        for x, y in loader:
            y, x = y.to(device), x.to(device)

            o = detector(x)

            # undo padding: negative padding crops the 8 rows again
            o = pad(o, [0, -8])
            y = pad(y, [0, -8])

            metrics.update(o, y)

    print(metrics.compute())

Setup datasets

# The Cityscapes dataset has to be downloaded manually, for example from
# https://www.cityscapes-dataset.com/ or with
# https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/download/downloader.py,
# and placed in the data/cityscapes folder.
dataset = Cityscapes(
    root="data/cityscapes",
    split="train",
    mode="fine",
    target_type="semantic",
    transforms=cityscapes_transform,
)

# Test dataset: original RoadAnomaly
dataset_test_roadanomaly_original = RoadAnomaly(
    root="data",
    download=True,
    transform=my_transform,
)

# Test datasets: the two SegmentMeIfYouCan subsets
dataset_test_SMIYC_RoadAnomaly21 = SegmentMeIfYouCan(
    root="data",
    subset="RoadAnomaly21",
    download=True,
    transform=my_transform,
)
dataset_test_SMIYC_RoadObstacle21 = SegmentMeIfYouCan(
    root="data",
    subset="RoadObstacle21",
    download=True,
    transform=my_transform,
)

Setup model

# FPN segmentation model on an ImageNet-initialised ResNet-50 encoder with one
# output channel per class.
model = smp.FPN(
    classes=classes,
    in_channels=3,
    encoder_name="resnet50",
    encoder_weights="imagenet",
)
# move the parameters to the configured device
model = model.to(device)

Train model for some epochs

# Put the model into training mode explicitly, so the loop does not depend on
# the mode the model happens to be in (e.g. when the script is re-run).
model.train()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
    worker_init_fn=fix_random_seed,
    generator=g,
)

# exponential moving averages, only used for the progress output
loss_ema = 0
iou_ema = 0

for epoch in range(num_epochs):
    for n, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        y, x = y.to(device), x.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        # The predicted class is the argmax over the class dimension of the
        # logits; a softmax beforehand does not change the ranking. The
        # statistics are for logging only, so no gradients are tracked.
        with torch.no_grad():
            tp, fp, fn, tn = smp.metrics.get_stats(
                y_hat.argmax(dim=1),
                y.long(),
                mode="multiclass",
                num_classes=classes,
            )
            iou = iou_score(tp, fp, fn, tn)

        loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
        iou_ema = 0.8 * iou_ema + 0.2 * iou.mean().item()

        # report progress every 10 batches
        if n % 10 == 0:
            print(
                f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {iou_ema:02.2f}"
            )

Evaluate

print("Evaluating")
# switch the model to evaluation mode before wrapping it in the detector
model.eval()
detector = EnergyBased(model)

# evaluate on each test set in turn, announcing the dataset first
for title, dataset_test in (
    ("RoadAnomaly Original dataset", dataset_test_roadanomaly_original),
    ("SegmentMeIfYouCan RoadAnomaly21 dataset", dataset_test_SMIYC_RoadAnomaly21),
    ("SegmentMeIfYouCan RoadObstacle21 dataset", dataset_test_SMIYC_RoadObstacle21),
):
    print(title)
    eval(dataset_test, detector)

Output:

| Dataset | Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|---|
| RoadAnomaly Original | Energy | 80.90 | 41.66 | 97.00 | 29.76 | 46.31 |
| SegmentMeIfYouCan RoadAnomaly21 | Energy | 83.69 | 40.69 | 96.31 | 43.54 | 47.92 |
| SegmentMeIfYouCan RoadObstacle21 | Energy | 87.34 | 35.00 | 99.97 | 28.34 | 20.45 |

Gallery generated by Sphinx-Gallery