Note
Go to the end to download the full example code.
StreetHazards + COCO objects
We train a Feature Pyramid Network (FPN) segmentation model
with a ResNet-50 backbone, pre-trained on ImageNet,
on the StreetHazards dataset.
During training, we insert random COCO objects into the images as anomalies to regularize the model.
Warning
The results produced by this script vary between runs. Exact numerical values cannot currently be reproduced, because the model includes operations for which no deterministic implementation exists at the time of writing.
from functools import partial

import segmentation_models_pytorch as smp
import torch
from segmentation_models_pytorch.encoders import get_preprocessing_fn
from segmentation_models_pytorch.metrics import iou_score
from torch.utils.data import DataLoader
from torchvision.transforms.functional import pad, to_tensor

from pytorch_ood.dataset.img import StreetHazards
from pytorch_ood.detector import Entropy
from pytorch_ood.loss import EntropicOpenSetLoss
from pytorch_ood.utils import OODMetrics, fix_random_seed
from pytorch_ood.utils.transforms import InsertCOCO
28
# Use the first GPU when available; otherwise fall back to the CPU so the
# example still runs (slowly) on machines without CUDA.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4
num_epochs = 1

# Seed the global RNGs via pytorch_ood's helper, and create a dedicated,
# seeded generator that is handed to the DataLoaders below.
fix_random_seed(12345)
g = torch.Generator()
g.manual_seed(0)
Set up preprocessing
# Input normalization matching the ImageNet-pretrained ResNet-50 encoder.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")

# Transform that pastes COCO objects into samples as synthetic anomalies.
# exclude_classes="Streethazards" presumably filters out COCO categories that
# overlap with StreetHazards classes -- TODO confirm against InsertCOCO.
# For demonstration purposes the OOD probability p is 1, i.e. it is applied
# to every sample it is called on.
coco_transform = InsertCOCO(
    p=1,
    coco_dir="data/coco",
    exclude_classes="Streethazards",
)
48
49
def my_transform(img, target, use_coco_transform):
    """Prepare one (image, segmentation target) pair for the FPN model.

    Optionally pastes COCO objects into the sample, converts the image to a
    3-channel tensor, applies the encoder preprocessing, and pads image and
    target so that the spatial size is divisible by 32.

    :param img: input image, in any format accepted by ``to_tensor``
    :param target: segmentation target belonging to ``img``
    :param use_coco_transform: if true, apply ``coco_transform`` first
    :return: tuple ``(img, target)`` after preprocessing and padding
    """
    if use_coco_transform:
        img, target = coco_transform(img, target)

    # Keep only the first three channels (drops a 4th channel, if present).
    rgb = to_tensor(img)[:3, :, :]

    # The preprocessing function is applied with the channel axis last:
    # CHW -> HWC, normalize, then HWC -> CHW again.
    channels_last = torch.moveaxis(rgb, 0, -1)
    normalized = torch.moveaxis(preprocess_input(channels_last), -1, 0)

    # Size must be divisible by 32: padding=[left/right, top/bottom] adds
    # 8 rows at the top and at the bottom of image and target.
    padding = [0, 8]
    return pad(normalized, padding).float(), pad(target, padding)
Set up datasets
# The split-specific flag is bound with functools.partial instead of a lambda.
# The training set is consumed by a DataLoader with worker processes
# (num_workers > 0). Under the "spawn" start method (the default on Windows
# and macOS) the dataset, including its transform, must be picklable, and
# lambdas are not.
dataset = StreetHazards(
    root="data",
    subset="train",
    transform=partial(my_transform, use_coco_transform=True),
    download=True,
)
dataset_test = StreetHazards(
    root="data",
    subset="test",
    transform=partial(my_transform, use_coco_transform=False),
    download=True,
)
Set up the model
# FPN segmentation network with an ImageNet-initialized ResNet-50 encoder,
# 3 input channels and 13 output classes, moved to the selected device.
model = smp.FPN(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=13,
)
model = model.to(device)
Train the model for a few epochs
criterion = EntropicOpenSetLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
    worker_init_fn=fix_random_seed,
    generator=g,
)

# Exponential moving averages (weight 0.2 on the newest value) of the loss and
# of the mean IoU; used only for progress logging.
loss_ema = 0
iou_ema = 0

model.train()  # explicit, so the loop is correct even if the model was put in eval mode
for epoch in range(num_epochs):
    for n, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        y, x = y.to(device), x.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        # Metrics are for logging only: compute them without building an
        # autograd graph. argmax over the logits selects the same class as
        # argmax over the softmax probabilities.
        with torch.no_grad():
            prediction = y_hat.argmax(dim=1)
            tp, fp, fn, tn = smp.metrics.get_stats(
                prediction,
                y.long(),
                mode="multiclass",
                num_classes=13,
            )
            iou = iou_score(tp, fp, fn, tn)

        loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
        iou_ema = 0.8 * iou_ema + 0.2 * iou.mean().item()

        if n % 10 == 0:
            print(
                f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {iou_ema:02.2f}"
            )
Evaluate
print("Evaluating")
model.eval()
# Use the configured batch size rather than a separately hard-coded value.
loader = DataLoader(dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g)
detector = Entropy(model)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        o = detector(x)

        # Undo the padding added in my_transform (8 rows at the top and at the
        # bottom) by cropping the height axis, i.e. the second-to-last dimension.
        o = o[..., 8:-8, :]
        y = y[..., 8:-8, :]

        metrics.update(o, y)

print(metrics.compute())
Output:

| Dataset | Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|---|
| StreetHazards+COCO | Entropy | 93.88 | 25.26 | 99.93 | 19.43 | 19.14 |