StreetHazards + COCO objects

We train a Feature Pyramid Network (FPN) segmentation model with an ImageNet-pretrained ResNet-50 backbone on the StreetHazards dataset. During training, we insert random COCO objects into the images as anomalies to regularize the model.

Warning

The results produced by this script may vary between runs. Exact numerical values currently cannot be reproduced, because the model uses operations for which no deterministic implementation existed at the time of writing.

import segmentation_models_pytorch as smp
import torch
from segmentation_models_pytorch.encoders import get_preprocessing_fn
from segmentation_models_pytorch.metrics import iou_score
from torch.utils.data import DataLoader
from torchvision.transforms.functional import pad, to_tensor

from pytorch_ood.dataset.img import StreetHazards
from pytorch_ood.detector import Entropy
from pytorch_ood.loss import EntropicOpenSetLoss
from pytorch_ood.utils import OODMetrics, fix_random_seed
from pytorch_ood.utils.transforms import InsertCOCO

# Use the first GPU when CUDA is available; otherwise fall back to the CPU
# (much slower, but lets the example run on machines without a GPU).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4
num_epochs = 1

# fix_random_seed presumably seeds the global RNGs (Python/NumPy/torch) - verify.
fix_random_seed(12345)
# Dedicated, seeded generator handed to the DataLoaders below (drives shuffling).
# manual_seed returns the generator itself.
g = torch.Generator().manual_seed(0)

Set up preprocessing

# Input normalization matching the ImageNet-pretrained ResNet-50 encoder.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")

# Pastes random COCO objects into training samples as synthetic anomalies.
# For demonstration purposes, the insertion probability p is 1, so every
# sample is modified. exclude_classes presumably filters out COCO classes that
# overlap with StreetHazards - verify against InsertCOCO.
coco_transform = InsertCOCO(
    p=1,
    coco_dir="data/coco",
    exclude_classes="Streethazards",
)


def my_transform(img, target, use_coco_transform):
    """Turn one StreetHazards sample into model-ready tensors.

    Optionally inserts a COCO object first, then converts the image to a
    tensor, applies the encoder's input preprocessing and pads image and
    target vertically so the model can process them.

    :param img: input image, in any format accepted by ``to_tensor``
    :param target: segmentation mask belonging to ``img``
    :param use_coco_transform: if true, apply ``coco_transform`` before anything else
    :return: tuple ``(img, target)``; the image is returned as float32
    """
    if use_coco_transform:
        img, target = coco_transform(img, target)

    # Keep only the first three channels (drops a 4th channel if present).
    chw = to_tensor(img)[:3]
    # preprocess_input is applied to channel-last data, then moved back to CHW.
    hwc = preprocess_input(chw.permute(1, 2, 0))
    chw = hwc.permute(2, 0, 1)

    # Size must be divisible by 32: a 2-element padding means (left/right, top/bottom),
    # so this adds 8 rows at the top and at the bottom and no columns.
    return pad(chw, [0, 8]).float(), pad(target, [0, 8])

Set up datasets

# Training split: COCO objects are inserted as anomalies into every sample.
dataset = StreetHazards(
    root="data",
    subset="train",
    transform=lambda img, target: my_transform(img, target, use_coco_transform=True),
    download=True,
)

# Test split: no synthetic anomalies are inserted.
dataset_test = StreetHazards(
    root="data",
    subset="test",
    transform=lambda img, target: my_transform(img, target, use_coco_transform=False),
    download=True,
)
76 )

Set up the model

# FPN segmentation network with an ImageNet-pretrained ResNet-50 encoder,
# taking RGB input and producing 13 output classes.
model = smp.FPN(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=13,
)
model = model.to(device)

Train the model for num_epochs epochs

 90 criterion = EntropicOpenSetLoss()
 91 optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
 92 loader = DataLoader(
 93     dataset,
 94     batch_size=batch_size,
 95     shuffle=True,
 96     num_workers=10,
 97     worker_init_fn=fix_random_seed,
 98     generator=g,
 99 )
100
101 ious = []
102 loss_ema = 0
103 ioe_ema = 0
104
105 for epoch in range(num_epochs):
106     for n, (x, y) in enumerate(loader):
107         optimizer.zero_grad()
108         y, x = y.to(device), x.to(device)
109
110         y_hat = model(x)
111         loss = criterion(y_hat, y)
112         loss.backward()
113         optimizer.step()
114
115         tp, fp, fn, tn = smp.metrics.get_stats(
116             y_hat.softmax(dim=1).max(dim=1).indices.long(),
117             y.long(),
118             mode="multiclass",
119             num_classes=13,
120         )
121         iou = iou_score(tp, fp, fn, tn)
122
123         loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
124         ioe_ema = 0.8 * ioe_ema + 0.2 * iou.mean().item()
125
126         if n % 10 == 0:
127             print(
128                 f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {ioe_ema:02.2f}"
129             )

Evaluate

print("Evaluating")
model.eval()
loader = DataLoader(dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g)
detector = Entropy(model)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        o = detector(x)

        # Undo the padding from my_transform. pad(..., [0, 8]) added 8 rows at the
        # top and bottom and no columns, so crop exactly those rows (negative
        # padding crops). Using [-8, -8] would also discard 8 real pixel columns
        # on each side and exclude them from the metrics.
        o = pad(o, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(o, y)

print(metrics.compute())

Output: {'AUROC': 0.9573410749435425, 'AUPR-IN': 0.5151191353797913, 'AUPR-OUT': 0.9991346001625061, 'FPR95TPR': 0.16139476001262665}

Gallery generated by Sphinx-Gallery