Note
Go to the end to download the full example code.
StreetHazards
We train a Feature Pyramid Network (FPN) segmentation model
with a ResNet-50 backbone pre-trained on ImageNet
on the StreetHazards dataset.
We then apply the EnergyBased OOD detector to the trained model.
Note
Training with a batch size of 4 requires slightly more than 12 GB of GPU memory. However, the models also tend to converge to reasonable performance with smaller batch sizes.
Warning
The results produced by this script vary between runs. Exact numerical values cannot currently be reproduced, because the model includes operations for which no deterministic implementation existed at the time of writing.
Note
The license of the model originally used for the StreetHazards dataset
is not compatible with pytorch-ood. This prevents us from reusing the implementation
from the original repository.
22 import segmentation_models_pytorch as smp
23 import torch
24 from segmentation_models_pytorch.encoders import get_preprocessing_fn
25 from segmentation_models_pytorch.metrics import iou_score
26 from torch.utils.data import DataLoader
27 from torchvision.transforms.functional import pad, to_tensor
28
29 from pytorch_ood.dataset.img import StreetHazards
30 from pytorch_ood.detector import EnergyBased
31 from pytorch_ood.utils import OODMetrics, fix_random_seed
32
# Run configuration shared by the training and evaluation sections below.
device = "cuda:0"  # model and every batch are moved to this device
batch_size = 4
num_epochs = 1

# Global seeding via pytorch-ood's helper, plus a dedicated generator that is
# handed to the DataLoaders (it drives the shuffling order of the train loader).
fix_random_seed(12345)
g = torch.Generator().manual_seed(0)  # manual_seed returns the generator itself
Set up preprocessing
# Input preprocessing matching the ImageNet-pretrained ResNet-50 encoder;
# it is applied to every image inside my_transform.
preprocess_input = get_preprocessing_fn(encoder_name="resnet50", pretrained="imagenet")
45
46
def my_transform(img, target):
    """Turn one StreetHazards sample into padded, normalized model inputs.

    The image is converted to a tensor, reduced to its first three channels,
    normalized with ``preprocess_input`` and padded. The target receives the
    identical spatial padding (default fill value 0), so image and target stay
    aligned.

    Args:
        img: Input image as delivered by the dataset (anything ``to_tensor``
            accepts).
        target: Segmentation target for ``img``.

    Returns:
        Tuple ``(img, target)``: the padded float image tensor and the padded
        target.
    """
    rgb = to_tensor(img)[:3]  # drop the 4th channel, keep the first three
    # preprocess_input works channel-last: CHW -> HWC, normalize, HWC -> CHW.
    normalized = preprocess_input(rgb.permute(1, 2, 0)).permute(2, 0, 1)

    # The network needs spatial sizes divisible by 32. A 2-element padding is
    # [left/right, top/bottom], i.e. 8 extra rows above and below, no columns.
    padding = [0, 8]
    return pad(normalized, padding).float(), pad(target, padding)
Set up datasets
# Both splits differ only in the subset name; they share the root directory,
# the transform and download=True (fetch the files into root if needed).
dataset, dataset_test = (
    StreetHazards(root="data", subset=split, transform=my_transform, download=True)
    for split in ("train", "test")
)
Set up model
# Feature Pyramid Network with an ImageNet-pretrained ResNet-50 encoder,
# taking 3-channel images and predicting 13 classes per pixel.
model = smp.FPN(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=13,
)
model = model.to(device)
Train model for some epochs
# Train the segmentation model with a multiclass Dice loss and periodically
# log exponential moving averages (EMA) of the loss and of the mean IoU.
criterion = smp.losses.DiceLoss(mode="multiclass")
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
    worker_init_fn=fix_random_seed,
    generator=g,  # seeded generator -> reproducible shuffling order
)

# EMAs are seeded with the first observed value instead of 0, so the first
# log lines are not biased towards zero.
loss_ema = None
iou_ema = None

model.train()  # explicit, even though freshly built modules start in train mode
for epoch in range(num_epochs):
    for n, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        y, x = y.to(device), x.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        # Monitoring only: no autograd graph needed. argmax over the logits
        # equals argmax over their softmax, so the softmax is skipped.
        with torch.no_grad():
            tp, fp, fn, tn = smp.metrics.get_stats(
                y_hat.argmax(dim=1),
                y.long(),
                mode="multiclass",
                num_classes=y_hat.shape[1],  # number of model output classes
            )
            iou = iou_score(tp, fp, fn, tn)

        loss_value = loss.item()
        iou_value = iou.mean().item()
        if loss_ema is None:
            loss_ema, iou_ema = loss_value, iou_value
        else:
            loss_ema = 0.8 * loss_ema + 0.2 * loss_value
            iou_ema = 0.8 * iou_ema + 0.2 * iou_value

        if n % 10 == 0:
            print(
                f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {iou_ema:02.2f}"
            )
Evaluate
# Evaluate OOD detection on the test split: the EnergyBased detector wraps the
# trained model and its scores are accumulated by OODMetrics against the
# ground-truth targets.
print("Evaluating")
model.eval()
loader = DataLoader(
    dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g
)
detector = EnergyBased(model)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        # presumably one score per pixel, spatially aligned with y (the crop
        # below assumes this)
        o = detector(x)

        # Undo the padding from my_transform: pad(..., [0, 8]) added 8 rows at
        # the top and bottom and no columns, so crop exactly those rows.
        # ([-8, -8] would additionally discard 8 valid columns on each side.)
        o = pad(o, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(o, y)

print(metrics.compute())
- Output:
{'AUROC': 0.8069181442260742, 'AUPR-IN': 0.07396415621042252, 'AUPR-OUT': 0.9966945648193359, 'FPR95TPR': 0.7595465183258057}