StreetHazards + COCO objects

We train a Feature Pyramid Network (FPN) segmentation model with a ResNet-50 backbone, pre-trained on ImageNet, on the StreetHazards dataset. During training, we insert random COCO objects into the images as anomalies to regularize the model.

Warning

The results produced by this script vary between runs. The exact numerical values cannot currently be reproduced, because the model uses operations for which no deterministic implementation exists at the time of writing.

16 import segmentation_models_pytorch as smp
17 import torch
18 from segmentation_models_pytorch.encoders import get_preprocessing_fn
19 from segmentation_models_pytorch.metrics import iou_score
20 from torch.utils.data import DataLoader
21 from torchvision.transforms.functional import pad, to_tensor
22
23 from pytorch_ood.dataset.img import StreetHazards
24 from pytorch_ood.detector import Entropy
25 from pytorch_ood.loss import EntropicOpenSetLoss
26 from pytorch_ood.utils import OODMetrics, fix_random_seed
27 from pytorch_ood.utils.transforms import InsertCOCO
28
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the example still runs (slowly) on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4
num_epochs = 1

# Seed the global random number generators through pytorch-ood's helper.
fix_random_seed(12345)
# Dedicated generator that is passed to the DataLoaders below.
g = torch.Generator()
g.manual_seed(0)

Set up preprocessing

# Input preprocessing that matches the ImageNet-pre-trained ResNet-50 encoder.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")

# Inserts COCO objects into training images as anomalies. For demonstration
# purposes, the insertion probability p is set to 1.
coco_transform = InsertCOCO(coco_dir="data/coco", exclude_classes="Streethazards", p=1)
48
49
def my_transform(img, target, use_coco_transform):
    """Turn a StreetHazards sample into a padded model input and target.

    :param img: input image; only its first three channels are kept
    :param target: segmentation target belonging to ``img``
    :param use_coco_transform: if true, ``coco_transform`` is applied to
        ``(img, target)`` before any other processing
    :return: tuple ``(img, target)``; the image is normalized with
        ``preprocess_input`` and cast to float, and both tensors are padded
        by 8 pixels at the top and at the bottom
    """
    if use_coco_transform:
        img, target = coco_transform(img, target)

    # to_tensor yields a channels-first tensor; keep the first three channels
    # and drop any 4th one.
    tensor = to_tensor(img)[:3, :, :]
    # preprocess_input works on channels-last data, so permute to HWC and back.
    tensor = preprocess_input(tensor.permute(1, 2, 0)).permute(2, 0, 1)

    # The spatial size must be divisible by 32, so pad 8 px on top and bottom
    # (a 2-element padding means [left/right, top/bottom]).
    padding = [0, 8]
    return pad(tensor, padding).float(), pad(target, padding)

Set up datasets

# Training split: COCO objects are inserted into every sample.
dataset = StreetHazards(
    root="data",
    subset="train",
    transform=lambda img, target: my_transform(img, target, use_coco_transform=True),
    download=True,
)

# Test split: the same preprocessing, but without COCO insertion.
dataset_test = StreetHazards(
    download=True,
    root="data",
    subset="test",
    transform=lambda img, target: my_transform(img, target, use_coco_transform=False),
)

Set up the model

# FPN segmentation network: ResNet-50 encoder initialized with ImageNet
# weights, 3 input channels and 13 output classes.
model = smp.FPN(
    classes=13,
    in_channels=3,
    encoder_weights="imagenet",
    encoder_name="resnet50",
)
model = model.to(device)

Train the model for `num_epochs` epochs

 91 criterion = EntropicOpenSetLoss()
 92 optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
 93 loader = DataLoader(
 94     dataset,
 95     batch_size=batch_size,
 96     shuffle=True,
 97     num_workers=10,
 98     worker_init_fn=fix_random_seed,
 99     generator=g,
100 )
101
102 ious = []
103 loss_ema = 0
104 ioe_ema = 0
105
106 for epoch in range(num_epochs):
107     for n, (x, y) in enumerate(loader):
108         optimizer.zero_grad()
109         y, x = y.to(device), x.to(device)
110
111         y_hat = model(x)
112         loss = criterion(y_hat, y)
113         loss.backward()
114         optimizer.step()
115
116         tp, fp, fn, tn = smp.metrics.get_stats(
117             y_hat.softmax(dim=1).max(dim=1).indices.long(),
118             y.long(),
119             mode="multiclass",
120             num_classes=13,
121         )
122         iou = iou_score(tp, fp, fn, tn)
123
124         loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
125         ioe_ema = 0.8 * ioe_ema + 0.2 * iou.mean().item()
126
127         if n % 10 == 0:
128             print(
129                 f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {ioe_ema:02.2f}"
130             )

Evaluate

print("Evaluating")
model.eval()
# Use the configured batch size instead of a separately hard-coded value.
loader = DataLoader(
    dataset_test,
    batch_size=batch_size,
    worker_init_fn=fix_random_seed,
    generator=g,
)
detector = Entropy(model)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        scores = detector(x)

        # Undo the 8-pixel top/bottom padding added in my_transform; negative
        # padding crops, so only original image pixels are scored.
        scores = pad(scores, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(scores, y)

print(metrics.compute())

Output:

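| Dataset            | Detector | AUROC | AUTC  | AUPR-IN | AUPR-OUT | FPR95TPR |
|--------------------|----------|-------|-------|---------|----------|----------|
| StreetHazards+COCO | Entropy  | 93.88 | 25.26 | 99.93   | 19.43    | 19.14    |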

Gallery generated by Sphinx-Gallery