Note
Go to the end to download the full example code.
StreetHazards + COCO objects
We train a Feature Pyramid Network (FPN) segmentation model
with a ResNet-50 backbone pre-trained on ImageNet
on the StreetHazards dataset.
During training, we insert random COCO objects as anomalies into the image to regularize the model.
Warning
The results produced by this script vary. It is impossible to ensure the reproducibility of the exact numerical values at the moment, because the model includes operations for which no deterministic implementation exists at the time of writing.
15 import segmentation_models_pytorch as smp
16 import torch
17 from segmentation_models_pytorch.encoders import get_preprocessing_fn
18 from segmentation_models_pytorch.metrics import iou_score
19 from torch.utils.data import DataLoader
20 from torchvision.transforms.functional import pad, to_tensor
21
22 from pytorch_ood.dataset.img import StreetHazards
23 from pytorch_ood.detector import Entropy
24 from pytorch_ood.loss import EntropicOpenSetLoss
25 from pytorch_ood.utils import OODMetrics, fix_random_seed
26 from pytorch_ood.utils.transforms import InsertCOCO
27
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so that the example still runs (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4  # samples per mini-batch
num_epochs = 1  # passes over the training set

# Fix the random seed with the pytorch_ood helper, then create a separately
# seeded generator that is passed to the data loaders below.
fix_random_seed(12345)
g = torch.Generator()
g.manual_seed(0)
Setup preprocessing
# Preprocessing function for the ImageNet-pretrained ResNet-50 encoder.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")

# Transform that inserts COCO objects into an image/target pair.
# For demonstration purposes, the probability of OOD is set to 1.
coco_transform = InsertCOCO(coco_dir="data/coco", exclude_classes="Streethazards", p=1)
47
48
def my_transform(img, target, use_coco_transform):
    """Prepare one image/target pair for the segmentation model.

    Args:
        img: input image, converted with ``to_tensor``.
        target: segmentation target that belongs to ``img``.
        use_coco_transform: if true, ``coco_transform`` is applied to the
            pair before any other step.

    Returns:
        Tuple ``(img, target)``: the preprocessed, padded image cast to
        float, and the padded target.
    """
    if use_coco_transform:
        img, target = coco_transform(img, target)

    # Keep only the first three channels (drops a 4th channel if present).
    chw = to_tensor(img)[:3, :, :]

    # preprocess_input is called on a channels-last layout, so move the
    # channel axis to the end for the call and back to the front afterwards.
    hwc = preprocess_input(torch.moveaxis(chw, 0, -1))
    chw = torch.moveaxis(hwc, -1, 0)

    # Size must be divisible by 32, so pad both image and target identically.
    padding = [0, 8]
    return pad(chw, padding).float(), pad(target, padding)
Setup datasets
def _train_transform(img, target):
    """Training transform: ``my_transform`` with the COCO transform enabled."""
    return my_transform(img, target, True)


def _test_transform(img, target):
    """Test transform: ``my_transform`` with the COCO transform disabled."""
    return my_transform(img, target, False)


# Training split of StreetHazards, stored under "data".
dataset = StreetHazards(
    root="data",
    subset="train",
    transform=_train_transform,
    download=True,
)

# Test split of StreetHazards, stored under "data".
dataset_test = StreetHazards(
    root="data",
    subset="test",
    transform=_test_transform,
    download=True,
)
Setup model
# FPN segmentation model with an ImageNet-pretrained ResNet-50 encoder,
# 3 input channels and 13 output classes.
model = smp.FPN(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=13,
)
# Move the model's parameters to the selected device.
model = model.to(device)
Train model for some epochs
criterion = EntropicOpenSetLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
    worker_init_fn=fix_random_seed,
    generator=g,
)

# Exponential moving averages; they only smooth the progress output below.
loss_ema = 0
iou_ema = 0

for epoch in range(num_epochs):
    for n, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        y, x = y.to(device), x.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        # The predicted class is the arg-max over the class dimension; a
        # softmax is not needed because it does not change the arg-max.
        tp, fp, fn, tn = smp.metrics.get_stats(
            y_hat.argmax(dim=1),
            y.long(),
            mode="multiclass",
            num_classes=13,
        )
        iou = iou_score(tp, fp, fn, tn)

        loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
        iou_ema = 0.8 * iou_ema + 0.2 * iou.mean().item()

        # Report progress every 10 batches.
        if n % 10 == 0:
            print(
                f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {iou_ema:02.2f}"
            )
Evaluate
print("Evaluating")
model.eval()
loader = DataLoader(
    dataset_test,
    batch_size=batch_size,
    worker_init_fn=fix_random_seed,
    generator=g,
)
detector = Entropy(model)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        o = detector(x)

        # Undo the padding applied in my_transform. That padding was [0, 8],
        # i.e. 0 px left/right and 8 px top/bottom, so only rows are removed;
        # cropping the width as well would discard real image columns.
        o = pad(o, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(o, y)

print(metrics.compute())
Output: {'AUROC': 0.9573410749435425, 'AUPR-IN': 0.5151191353797913, 'AUPR-OUT': 0.9991346001625061, 'FPR95TPR': 0.16139476001262665}