Note
Go to the end to download the full example code.
StreetHazards
We train a Feature Pyramid Network (FPN) segmentation model
with a ResNet-50 backbone pre-trained on ImageNet
on the StreetHazards dataset.
We then use the energy-based OOD detector (EnergyBased).
Note
Training with a batch size of 4 requires slightly more than 12 GB of GPU memory. However, the model also tends to converge to reasonable performance with a smaller batch size.
Warning
The results produced by this script vary between runs. Exact numerical reproducibility cannot currently be guaranteed, because the model includes operations for which no deterministic implementation exists at the time of writing.
Note
The license of the model originally used for the StreetHazards dataset
is not compatible with pytorch-ood. This prevents us from reusing the implementation
from the original repository.
23 import segmentation_models_pytorch as smp
24 import torch
25 from segmentation_models_pytorch.encoders import get_preprocessing_fn
26 from segmentation_models_pytorch.metrics import iou_score
27 from torch.utils.data import DataLoader
28 from torchvision.transforms.functional import pad, to_tensor
29
30 from pytorch_ood.dataset.img import StreetHazards
31 from pytorch_ood.detector import EnergyBased
32 from pytorch_ood.utils import OODMetrics, fix_random_seed
33
# Use the first CUDA device when one is available; otherwise fall back to the CPU
# so the example still runs (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4  # slightly more than 12 GB of GPU memory at this size
num_epochs = 1

# Seed via pytorch-ood's helper, and create a dedicated, seeded generator that is
# handed to the DataLoaders below.
fix_random_seed(12345)
g = torch.Generator()
g.manual_seed(0)
Set up preprocessing
# Normalisation function matching the ImageNet-pretrained ResNet-50 encoder weights.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")


def my_transform(img, target):
    """Turn one dataset sample into model-ready tensors.

    Processing steps:

    - Convert the image to a tensor and keep only its first three channels.
    - Normalise it with ``preprocess_input``, which is applied channels-last.
    - Pad it by 8 rows at the top and bottom.
    - Pad the target with the same spec so image and target stay aligned.

    :param img: input image, anything ``to_tensor`` accepts
    :param target: segmentation target for ``img``
    :return: tuple ``(img, target)`` after normalisation and padding
    """
    # Drop a possible 4th channel so only three channels reach the encoder.
    chw = to_tensor(img)[:3, :, :]
    # preprocess_input is applied on a channels-last view: CHW -> HWC -> CHW.
    normalized = torch.moveaxis(preprocess_input(torch.moveaxis(chw, 0, -1)), -1, 0)

    # The spatial size must be divisible by 32. A length-2 padding spec means
    # (left/right, top/bottom), so this adds 8 rows above and below.
    padding = [0, 8]
    return pad(normalized, padding).float(), pad(target, padding)
Set up datasets
# Both splits share the same root directory, transform and download flag.
_dataset_kwargs = dict(root="data", transform=my_transform, download=True)
dataset = StreetHazards(subset="train", **_dataset_kwargs)
dataset_test = StreetHazards(subset="test", **_dataset_kwargs)
Set up the model
# FPN segmentation network: ImageNet-pretrained ResNet-50 encoder,
# 3 input channels and 13 output classes.
model = smp.FPN(encoder_name="resnet50", encoder_weights="imagenet", in_channels=3, classes=13)
# Place the parameters on the configured device.
model = model.to(device)
Train the model
criterion = smp.losses.DiceLoss(mode="multiclass")
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
    worker_init_fn=fix_random_seed,
    generator=g,
)

# Exponential moving averages (weight 0.2 on the newest value) used only for logging.
# Both start at 0, so the first logged values are biased toward 0.
loss_ema = 0
iou_ema = 0

# Be explicit about training mode rather than relying on the model's initial state.
model.train()
for epoch in range(num_epochs):
    for n, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        y, x = y.to(device), x.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        # The training IoU is a monitoring metric only, so keep it out of autograd.
        # Argmax over the class dimension of the logits gives the same predicted class
        # as argmax over their softmax, without computing the softmax.
        with torch.no_grad():
            tp, fp, fn, tn = smp.metrics.get_stats(
                y_hat.argmax(dim=1),
                y.long(),
                mode="multiclass",
                num_classes=13,
            )
            iou = iou_score(tp, fp, fn, tn)

        loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
        iou_ema = 0.8 * iou_ema + 0.2 * iou.mean().item()

        if n % 10 == 0:
            print(
                f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {iou_ema:02.2f}"
            )
Evaluate on the test set
print("Evaluating")
model.eval()
# Use the configured batch size rather than a separate hard-coded value.
loader = DataLoader(dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g)
# Energy-based OOD scores computed from the trained model's outputs.
detector = EnergyBased(model)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        o = detector(x)

        # Undo the 8-row top/bottom padding added by my_transform;
        # a negative padding amount crops instead of padding.
        o = pad(o, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(o, y)

print(metrics.compute())
Output:
| Dataset | Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|---|
| StreetHazards | Energy | 81.93 | 42.28 | 99.70 | 9.05 | 57.43 |