Note
Go to the end to download the full example code.
RoadAnomaly
We train a Feature Pyramid Network (FPN) segmentation model
with a ResNet-50 backbone pre-trained on ImageNet
on the Cityscapes dataset (download it beforehand and put the gtFine and leftImg8bit folders into the data/cityscapes folder).
This model is evaluated with the EnergyBased OOD detector on the original RoadAnomaly dataset and on both datasets of the SegmentMeIfYouCan benchmark: RoadAnomaly21 and RoadObstacle21.
Note
Training with a batch size of 4 requires slightly more than 12 GB of GPU memory. However, the model also tends to converge to reasonable performance with a smaller batch size.
Warning
The results produced by this script vary between runs. The exact numerical values cannot currently be reproduced, because the model includes operations for which no deterministic implementation exists at the time of writing.
20 import segmentation_models_pytorch as smp
21 import torch
22 from segmentation_models_pytorch.encoders import get_preprocessing_fn
23 from segmentation_models_pytorch.metrics import iou_score
24 from torch.utils.data import DataLoader
25 from torchvision.transforms.functional import pad, to_tensor
26 from torchvision.datasets import Cityscapes
27 from PIL import Image
28
29 from pytorch_ood.dataset.img import RoadAnomaly, SegmentMeIfYouCan
30 from pytorch_ood.detector import EnergyBased
31 from pytorch_ood.utils import OODMetrics, fix_random_seed
32
# Run configuration. The device is hard-coded to the first CUDA GPU.
device = "cuda:0"
# Training batch size, number of epochs, and number of output classes of the
# segmentation head (also passed to the IoU statistics below).
batch_size, num_epochs, classes = 4, 1, 34

# pytorch-ood helper, presumably seeding the global RNGs (confirm in pytorch_ood.utils).
fix_random_seed(12345)
# Dedicated, separately seeded generator handed to the DataLoaders below.
g = torch.Generator().manual_seed(0)
Setup preprocessing
# Input preprocessing matching the "resnet50" encoder's ImageNet pre-training.
preprocess_input = get_preprocessing_fn(encoder_name="resnet50", pretrained="imagenet")
45
46
def my_transform(img, target):
    """Turn an (image, label mask) pair into padded, model-ready tensors.

    Steps:

    1. Convert the image with ``to_tensor`` and keep only its first three
       channels.
    2. Normalise it with ``preprocess_input``.
    3. If the spatial size is not 720x1280 (H x W), resize the image
       bilinearly and the mask with nearest-neighbour interpolation, so
       label values are never blended.
    4. Pad both with ``pad(..., [0, 8])``. With torchvision's 2-element
       padding ``[left/right, top/bottom]`` this adds 8 rows at the top and
       at the bottom, so 720 becomes 736, a multiple of 32.

    :param img: image accepted by ``to_tensor`` (e.g. a PIL image)
    :param target: label mask; when resizing is needed it is indexed as
        ``target[None, None]``, so it is expected to be a 2-D tensor
    :return: tuple ``(img, target)``; ``img`` is cast to float32
    """
    # Keep at most three channels (drops a 4th, e.g. alpha, channel).
    x = to_tensor(img)[:3, :, :]
    # preprocess_input is applied channels-last: move C to the back,
    # normalise, then move it to the front again.
    x = torch.moveaxis(preprocess_input(torch.moveaxis(x, 0, -1)), -1, 0)

    out_size = (720, 1280)
    if tuple(x.shape[-2:]) != out_size:
        x = torch.nn.functional.interpolate(
            x.unsqueeze(0), size=out_size, mode="bilinear", align_corners=False
        ).squeeze(0)
        target = torch.nn.functional.interpolate(
            target.float()[None, None], size=out_size, mode="nearest"
        )[0, 0].long()

    # The FPN needs spatial sizes divisible by 32, hence the height padding.
    return pad(x, [0, 8]).float(), pad(target, [0, 8])
67
68
def cityscapes_transform(img, target):
    """Resize a Cityscapes (image, semantic mask) pair and hand it to ``my_transform``.

    Both inputs are resized to 1280x720 (W x H). The mask uses
    nearest-neighbour resampling so label ids are not blended.

    :param img: PIL image from ``torchvision.datasets.Cityscapes``
    :param target: PIL semantic-label image from the same dataset
    :return: the ``(img, target)`` tuple produced by ``my_transform``
    """
    img = img.resize((1280, 720))
    # Nearest-neighbour keeps every pixel a valid label id.
    target = target.resize((1280, 720), Image.NEAREST)
    # to_tensor scales 8-bit values into [0, 1], so multiply back by 255.
    # Round before casting: .long() truncates, and a value that lands just
    # below an integer would otherwise become the previous label id.
    target = (to_tensor(target).squeeze(0) * 255).round().long()
    return my_transform(img, target)
77
78
def eval(dataset_test, detector):
    """Run ``detector`` over ``dataset_test`` and print and return the OOD metrics.

    NOTE: this name shadows Python's built-in ``eval``. It is kept because the
    evaluation section of this script calls it by this name.

    :param dataset_test: dataset yielding ``(x, y)`` pairs padded by
        ``my_transform``
    :param detector: callable mapping a batch of images to per-pixel
        outlier scores (here a pytorch-ood ``EnergyBased`` detector)
    :return: the result of ``OODMetrics.compute()`` (also printed)
    """
    metrics = OODMetrics(mode="segmentation", void_label=1)
    loader = DataLoader(dataset_test, batch_size=4, worker_init_fn=fix_random_seed, generator=g)

    with torch.no_grad():
        for x, y in loader:
            y, x = y.to(device), x.to(device)

            o = detector(x)

            # Undo the 8-row top/bottom padding added by my_transform;
            # negative padding crops.
            o = pad(o, [0, -8])
            y = pad(y, [0, -8])

            metrics.update(o, y)

    results = metrics.compute()
    print(results)
    return results
Setup datasets
# Cityscapes must be obtained manually, e.g. from https://www.cityscapes-dataset.com/
# or with https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/download/downloader.py
# Place it (the gtFine and leftImg8bit folders) in the data/cityscapes folder.
dataset = Cityscapes(
    root="data/cityscapes",
    split="train",
    mode="fine",
    target_type="semantic",
    transforms=cityscapes_transform,
)

# Evaluation data: the original RoadAnomaly dataset (download=True is passed).
dataset_test_roadanomaly_original = RoadAnomaly(
    root="data", transform=my_transform, download=True
)

# Evaluation data: both subsets of the SegmentMeIfYouCan benchmark.
dataset_test_SMIYC_RoadAnomaly21 = SegmentMeIfYouCan(
    root="data", transform=my_transform, subset="RoadAnomaly21", download=True
)
dataset_test_SMIYC_RoadObstacle21 = SegmentMeIfYouCan(
    root="data", transform=my_transform, subset="RoadObstacle21", download=True
)
Setup model
# FPN segmentation network with a ResNet-50 encoder initialised from ImageNet
# weights: 3 input channels and `classes` output channels.
model = smp.FPN(
    classes=classes,
    in_channels=3,
    encoder_weights="imagenet",
    encoder_name="resnet50",
)
model = model.to(device)
Train model for some epochs
# Per-pixel cross-entropy over the class logits, optimised with Adam.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
    worker_init_fn=fix_random_seed,
    generator=g,
)

# Exponential moving averages (weight 0.2 on the newest batch) used only for logging.
loss_ema = 0
ioe_ema = 0  # EMA of the batch-mean IoU

# Explicit, so re-running this section after model.eval() still trains correctly.
model.train()

for epoch in range(num_epochs):
    for n, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        y, x = y.to(device), x.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        # Monitoring only, so no autograd graph is needed. The argmax over the
        # class dimension yields the same labels as the argmax of the softmax.
        with torch.no_grad():
            tp, fp, fn, tn = smp.metrics.get_stats(
                y_hat.argmax(dim=1),
                y.long(),
                mode="multiclass",
                num_classes=classes,
            )
            iou = iou_score(tp, fp, fn, tn)

        loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
        ioe_ema = 0.8 * ioe_ema + 0.2 * iou.mean().item()

        if n % 10 == 0:
            print(
                f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {ioe_ema:02.2f}"
            )
Evaluate
print("Evaluating")
model.eval()
# Outlier scores come from the trained model via the energy-based detector.
detector = EnergyBased(model)

# (title printed before the metrics, dataset to evaluate)
for title, test_set in (
    ("RoadAnomaly Original dataset", dataset_test_roadanomaly_original),
    ("SegmentMeIfYouCan RoadAnomaly21 dataset", dataset_test_SMIYC_RoadAnomaly21),
    ("SegmentMeIfYouCan RoadObstacle21 dataset", dataset_test_SMIYC_RoadObstacle21),
):
    print(title)
    eval(test_set, detector)
Output:

| Dataset | Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|---|
| RoadAnomaly Original | Energy | 80.90 | 41.66 | 97.00 | 29.76 | 46.31 |
| SegmentMeIfYouCan RoadAnomaly21 | Energy | 83.69 | 40.69 | 96.31 | 43.54 | 47.92 |
| SegmentMeIfYouCan RoadObstacle21 | Energy | 87.34 | 35.00 | 99.97 | 28.34 | 20.45 |