StreetHazards

We train a Feature Pyramid Network (FPN) segmentation model, with a ResNet-50 backbone pre-trained on ImageNet, on the StreetHazards dataset. We then use the EnergyBased OOD detector.

Note

Training with a batch-size of 4 requires slightly more than 12 GB of GPU memory. However, the models tend to also converge to reasonable performance with a smaller batch-size.

Warning

The results produced by this script vary. It is impossible to ensure the reproducibility of the exact numerical values at the moment, because the model includes operations for which no deterministic implementation exists at the time of writing.

Note

The license of the model originally used for the StreetHazards dataset is not compatible with pytorch-ood. This prevents us from re-using the implementation from the original repository.

22 import segmentation_models_pytorch as smp
23 import torch
24 from segmentation_models_pytorch.encoders import get_preprocessing_fn
25 from segmentation_models_pytorch.metrics import iou_score
26 from torch.utils.data import DataLoader
27 from torchvision.transforms.functional import pad, to_tensor
28
29 from pytorch_ood.dataset.img import StreetHazards
30 from pytorch_ood.detector import EnergyBased
31 from pytorch_ood.utils import OODMetrics, fix_random_seed
32
# Run configuration: target device and training hyper-parameters.
device = "cuda:0"
batch_size = 4
num_epochs = 1

# Fix the random seed via the pytorch-ood helper, then create a separately
# seeded generator that is handed to the data loaders below.
fix_random_seed(12345)
g = torch.Generator().manual_seed(0)

Setup preprocessing

# Input normalization matching the ImageNet-pretrained ResNet-50 encoder.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")


def my_transform(img, target):
    """Convert an image/target pair into padded inputs for the model.

    The image is converted to a tensor, reduced to its first three channels,
    normalized with ``preprocess_input`` (which is applied on a channels-last
    layout) and moved back to channels-first. Image and target are then padded
    by 8 pixels at the top and bottom, because the size must be divisible
    by 32.

    Returns:
        The ``(img, target)`` tuple, with the image cast to float.
    """
    rgb = to_tensor(img)[:3]  # drop 4th channel
    channels_last = torch.moveaxis(rgb, 0, -1)
    normalized = preprocess_input(channels_last)
    channels_first = torch.moveaxis(normalized, -1, 0)

    # [left/right, top/bottom]: no horizontal padding, 8 pixels vertically.
    padding = [0, 8]
    return pad(channels_first, padding).float(), pad(target, padding)

Setup datasets

# Build the train and test splits with identical settings; only ``subset``
# differs. The train split is constructed first, then the test split.
dataset, dataset_test = (
    StreetHazards(root="data", subset=split, transform=my_transform, download=True)
    for split in ("train", "test")
)

Setup model

# FPN segmentation network on a ResNet-50 encoder initialized with ImageNet
# weights, taking 3-channel input and predicting 13 classes.
model = smp.FPN(
    classes=13,
    in_channels=3,
    encoder_weights="imagenet",
    encoder_name="resnet50",
)
# Move the parameters to the configured device.
model = model.to(device)

Train model for some epochs

 80 criterion = smp.losses.DiceLoss(mode="multiclass")
 81 optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
 82 loader = DataLoader(
 83     dataset,
 84     batch_size=batch_size,
 85     shuffle=True,
 86     num_workers=10,
 87     worker_init_fn=fix_random_seed,
 88     generator=g,
 89 )
 90
 91 ious = []
 92 loss_ema = 0
 93 ioe_ema = 0
 94
 95 for epoch in range(num_epochs):
 96     for n, (x, y) in enumerate(loader):
 97         optimizer.zero_grad()
 98         y, x = y.to(device), x.to(device)
 99
100         y_hat = model(x)
101         loss = criterion(y_hat, y)
102         loss.backward()
103         optimizer.step()
104
105         tp, fp, fn, tn = smp.metrics.get_stats(
106             y_hat.softmax(dim=1).max(dim=1).indices.long(),
107             y.long(),
108             mode="multiclass",
109             num_classes=13,
110         )
111         iou = iou_score(tp, fp, fn, tn)
112
113         loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
114         ioe_ema = 0.8 * ioe_ema + 0.2 * iou.mean().item()
115
116         if n % 10 == 0:
117             print(
118                 f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {ioe_ema:02.2f}"
119             )

Evaluate

print("Evaluating")
model.eval()
loader = DataLoader(
    dataset_test, batch_size=4, worker_init_fn=fix_random_seed, generator=g
)
detector = EnergyBased(model)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        # Outlier scores from the energy-based detector.
        o = detector(x)

        # Undo the padding added by my_transform, which used [0, 8]: nothing
        # on the left/right, 8 pixels on the top/bottom. Negative padding
        # crops, so only the vertical component may be negative here; a
        # negative horizontal value would cut away real image columns.
        o = pad(o, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(o, y)

print(metrics.compute())
Output:

{'AUROC': 0.8069181442260742, 'AUPR-IN': 0.07396415621042252, 'AUPR-OUT': 0.9966945648193359, 'FPR95TPR': 0.7595465183258057}

Gallery generated by Sphinx-Gallery