StreetHazards

We train a Feature Pyramid Network (FPN) segmentation model with a ResNet-50 backbone, pre-trained on ImageNet, on the StreetHazards dataset. We then use the energy-based OOD detector (`EnergyBased`) to score out-of-distribution pixels.

Note

Training with a batch size of 4 requires slightly more than 12 GB of GPU memory. However, the model also tends to converge to reasonable performance with a smaller batch size.

Warning

The results produced by this script vary between runs. Exact numerical values cannot currently be reproduced, because the model includes operations for which no deterministic implementation exists at the time of writing.

Note

The license of the model originally used for the StreetHazards dataset is not compatible with pytorch-ood's license. This prevents us from re-using the implementation from the original repository.

23 import segmentation_models_pytorch as smp
24 import torch
25 from segmentation_models_pytorch.encoders import get_preprocessing_fn
26 from segmentation_models_pytorch.metrics import iou_score
27 from torch.utils.data import DataLoader
28 from torchvision.transforms.functional import pad, to_tensor
29
30 from pytorch_ood.dataset.img import StreetHazards
31 from pytorch_ood.detector import EnergyBased
32 from pytorch_ood.utils import OODMetrics, fix_random_seed
33
# Run on the first GPU when CUDA is available; otherwise fall back to the CPU
# so the example can still be executed (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# A batch size of 4 needs slightly more than 12 GB of GPU memory; smaller
# values also tend to converge to reasonable performance.
batch_size = 4
num_epochs = 1

# fix_random_seed presumably seeds the global Python/NumPy/torch RNGs
# (pytorch-ood helper; confirm). The dedicated generator `g` is passed to the
# DataLoaders so their shuffling uses an independently seeded stream.
fix_random_seed(12345)
g = torch.Generator()
g.manual_seed(0)

Setup preprocessing

# Normalization function matching the ImageNet pre-trained ResNet-50 encoder.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")


def my_transform(img, target):
    """Turn a raw StreetHazards sample into padded, normalized tensors.

    The image is converted with ``to_tensor`` and restricted to its first
    three channels. It is then normalized with ``preprocess_input``, which is
    applied on channel-last (H, W, C) data. Finally the image and the target
    are both padded by 8 rows at the top and at the bottom, because the model
    needs spatial sizes divisible by 32.

    :param img: input image accepted by ``to_tensor``; any channel beyond the
        third is discarded
    :param target: segmentation target; assumed to be a tensor with the same
        height/width as ``img`` (TODO confirm against the dataset)
    :return: ``(image, target)`` where the image is cast to float32
    """
    rgb = to_tensor(img)[:3, :, :]  # keep the first three channels only

    # CHW -> HWC for preprocess_input, then back to CHW for the model.
    normalized = preprocess_input(rgb.permute(1, 2, 0)).permute(2, 0, 1)

    # For a length-2 sequence torchvision pads [left/right, top/bottom]:
    # 0 columns on each side and 8 rows at the top and at the bottom.
    padding = [0, 8]
    return pad(normalized, padding).float(), pad(target, padding)

Setup datasets

# Both splits share the same transform and are stored under "data";
# download=True lets the dataset class fetch the files (presumably only when
# they are not already present).
dataset, dataset_test = (
    StreetHazards(root="data", subset=split, transform=my_transform, download=True)
    for split in ("train", "test")
)

Setup model

# FPN segmentation network: 3 input channels, 13 output maps (one per class,
# indexed along dim 1), ResNet-50 encoder initialized from ImageNet weights.
model = smp.FPN(
    classes=13,
    in_channels=3,
    encoder_name="resnet50",
    encoder_weights="imagenet",
)
model = model.to(device)

Train model for some epochs

 77 criterion = smp.losses.DiceLoss(mode="multiclass")
 78 optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
 79 loader = DataLoader(
 80     dataset,
 81     batch_size=batch_size,
 82     shuffle=True,
 83     num_workers=10,
 84     worker_init_fn=fix_random_seed,
 85     generator=g,
 86 )
 87
 88 ious = []
 89 loss_ema = 0
 90 ioe_ema = 0
 91
 92 for epoch in range(num_epochs):
 93     for n, (x, y) in enumerate(loader):
 94         optimizer.zero_grad()
 95         y, x = y.to(device), x.to(device)
 96
 97         y_hat = model(x)
 98         loss = criterion(y_hat, y)
 99         loss.backward()
100         optimizer.step()
101
102         tp, fp, fn, tn = smp.metrics.get_stats(
103             y_hat.softmax(dim=1).max(dim=1).indices.long(),
104             y.long(),
105             mode="multiclass",
106             num_classes=13,
107         )
108         iou = iou_score(tp, fp, fn, tn)
109
110         loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
111         ioe_ema = 0.8 * ioe_ema + 0.2 * iou.mean().item()
112
113         if n % 10 == 0:
114             print(
115                 f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {ioe_ema:02.2f}"
116             )

Evaluate

print("Evaluating")
model.eval()
# Reuse the configured batch size instead of a second hard-coded value.
loader = DataLoader(dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g)
detector = EnergyBased(model)
metrics = OODMetrics(mode="segmentation")

# Rows added at the top and at the bottom of every image by the transform.
pad_rows = 8

with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        o = detector(x)

        # Undo the padding by slicing the padded rows off the height axis
        # (second-to-last dimension). This is equivalent to
        # pad(..., [0, -8]) without relying on negative-padding support.
        o = o[..., pad_rows:-pad_rows, :]
        y = y[..., pad_rows:-pad_rows, :]

        metrics.update(o, y)

print(metrics.compute())
137 print(metrics.compute())

Output:

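| Dataset       | Detector | AUROC | AUTC  | AUPR-IN | AUPR-OUT | FPR95TPR |
|---------------|----------|-------|-------|---------|----------|----------|
| StreetHazards | Energy   | 81.93 | 42.28 | 99.70   | 9.05     | 57.43    |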

Gallery generated by Sphinx-Gallery