Note
Go to the end to download the full example code.
StreetHazards with VOS Loss
We train a Feature Pyramid Network (FPN) segmentation model
with a ResNet-50 backbone pre-trained on ImageNet
on the StreetHazards test set using
the supervised VOSRegLoss.
We then use the WeightedEBO OOD detector.
This setup is merely made to demonstrate how to train a supervised anomaly segmentation model with this loss function.
Warning
We train on the test set, as it contains examples of anomalies. The results will not be meaningful.
Note
Training with a batch size of 4 requires slightly more than 12 GB of GPU memory. However, the models also tend to converge to reasonable performance with a smaller batch size. This loss is more effective with a learning-rate scheduler and many epochs.
25 import numpy as np
26 import segmentation_models_pytorch as smp
27 import torch
28 from segmentation_models_pytorch.encoders import get_preprocessing_fn
29 from segmentation_models_pytorch.metrics import iou_score
30 from torch.utils.data import DataLoader
31 from torchvision.transforms.functional import pad, to_tensor
32
33 from pytorch_ood.dataset.img import StreetHazards
34 from pytorch_ood.detector import WeightedEBO
35 from pytorch_ood.loss import VOSRegLoss
36 from pytorch_ood.utils import OODMetrics, fix_random_seed
37
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so the example still runs (slowly) on machines without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Training hyper-parameters.
batch_size = 4
num_epochs = 1
lr = 0.0001
num_classes = 13  # number of output classes of the segmentation model

# Fix the random seed through the pytorch_ood helper and create a dedicated,
# seeded generator that is passed to the data loaders.
fix_random_seed(12345)
g = torch.Generator().manual_seed(0)
Setup preprocessing
# Preprocessing function for the "resnet50" encoder with "imagenet" weights.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")


def my_transform(img, target):
    """Turn an image/target pair into padded inputs for the model.

    The image is converted to a tensor, cut down to its first three channels,
    passed through ``preprocess_input`` and padded. The target receives the
    same padding so that image and target stay aligned.

    :param img: image accepted by ``to_tensor``
    :param target: segmentation target belonging to ``img``
    :returns: tuple of the padded float image and the padded target
    """
    # size must be divisible by 32, so image and target are both padded with
    # this [left/right, top/bottom] amount.
    padding = [0, 8]

    rgb = to_tensor(img)[:3, :, :]  # drop 4th channel

    # preprocess_input is applied with the channel axis last, so move the
    # channels to the back, preprocess, and move them to the front again.
    channels_last = torch.moveaxis(rgb, 0, -1)
    preprocessed = preprocess_input(channels_last)
    channels_first = torch.moveaxis(preprocessed, -1, 0)

    padded_img = pad(channels_first, padding).float()
    padded_target = pad(target, padding)
    return padded_img, padded_target
64
65
def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Return the cosine-annealed value for ``step`` out of ``total_steps``.

    The result starts at ``lr_max`` for ``step == 0`` and follows half a
    cosine wave down to ``lr_min`` at ``step == total_steps``.

    :param step: current step
    :param total_steps: number of steps over which the value is annealed
    :param lr_max: value returned at the first step
    :param lr_min: value returned at the last step
    :returns: the annealed value
    """
    span = lr_max - lr_min
    phase = step / total_steps * np.pi
    cosine_term = 1 + np.cos(phase)
    return lr_min + span * 0.5 * cosine_term
Set up the datasets. We train on OOD images for demonstration purposes.
# Training and evaluation deliberately use the same StreetHazards "test"
# subset, so both datasets are built from identical arguments.
dataset_kwargs = {
    "root": "data",
    "subset": "test",
    "transform": my_transform,
    "download": True,
}
dataset = StreetHazards(**dataset_kwargs)
dataset_test = StreetHazards(**dataset_kwargs)
Setup model
# FPN segmentation model on a ResNet-50 encoder initialised with ImageNet
# weights; it takes 3-channel input and predicts ``num_classes`` classes.
fpn_options = {
    "encoder_name": "resnet50",
    "encoder_weights": "imagenet",
    "in_channels": 3,
    "classes": num_classes,
}
model = smp.FPN(**fpn_options)
model = model.to(device)
Create neural network functions (layers)
# Linear layer mapping one input feature to two outputs, handed to the loss.
phi = torch.nn.Linear(in_features=1, out_features=2).to(device)

# Linear layer mapping the per-class outputs to a single value; its weights
# are re-initialised uniformly in [0, 1] after the layer is on the device.
weights_energy = torch.nn.Linear(in_features=num_classes, out_features=1).to(device)
torch.nn.init.uniform_(weights_energy.weight, a=0.0, b=1.0)

criterion = VOSRegLoss(phi, weights_energy, device=device)
Train model for some epochs
# phi and weights_energy are learnable layers used inside the loss, so they
# have to be optimised together with the segmentation model; otherwise they
# would keep their initial values for the whole training run.
trainable_params = (
    list(model.parameters())
    + list(phi.parameters())
    + list(weights_energy.parameters())
)
optimizer = torch.optim.Adam(params=trainable_params, lr=lr)


loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
    worker_init_fn=fix_random_seed,
    generator=g,
)

# setup scheduler for optimizer (recommended)
# The total step count is computed once here so that the schedule does not
# depend on whatever ``loader`` refers to when the lambda is called.
total_steps = num_epochs * len(loader)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: cosine_annealing(
        step,
        total_steps,
        1,  # since lr_lambda computes multiplicative factor
        1e-6 / lr,
    ),
)

# Exponential moving averages of loss and IoU, used only for logging.
loss_ema = 0
iou_ema = 0

for epoch in range(num_epochs):
    for n, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        y, x = y.to(device), x.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        scheduler.step()  # the schedule advances once per batch

        # Per-pixel class prediction; the argmax of the logits equals the
        # argmax of their softmax.
        tp, fp, fn, tn = smp.metrics.get_stats(
            y_hat.argmax(dim=1),
            y.long(),
            mode="multiclass",
            num_classes=num_classes,
        )
        iou = iou_score(tp, fp, fn, tn)

        loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
        iou_ema = 0.8 * iou_ema + 0.2 * iou.mean().item()

        if n % 10 == 0:
            print(
                f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {iou_ema:02.2f}"
            )
Evaluate
print("Evaluating")
model.eval()

# Use the configured batch size instead of a second hard-coded value.
loader = DataLoader(
    dataset_test,
    batch_size=batch_size,
    worker_init_fn=fix_random_seed,
    generator=g,
)
detector = WeightedEBO(model, weights_energy)
metrics = OODMetrics(mode="segmentation")

# No gradients are needed while scoring the evaluation data.
with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        o = detector(x)

        # undo padding: negative padding crops what my_transform added
        o = pad(o, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(o, y)

print(metrics.compute())
Output:
| Dataset | Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|---|
| Streethazards+VOS-Loss | WeightedEBO | 93.56 | 36.51 | 99.94 | 15.45 | 17.98 |