StreetHazards with VOS Loss

We train a Feature Pyramid Network (FPN) segmentation model with a ResNet-50 backbone pre-trained on ImageNet. It is trained on the StreetHazards test set using the supervised VOSRegLoss.

We then use the VOS-based WeightedEBO OOD detector, which reuses the energy-weighting layer from the loss.

This setup is only intended to demonstrate how to train a supervised anomaly segmentation model with this loss function.

Warning

We train on the test set because it contains examples of anomalies. Consequently, the evaluation results are not meaningful.

Note

Training with a batch size of 4 requires slightly more than 12 GB of GPU memory. However, the model also tends to converge to reasonable performance with a smaller batch size. This loss is more effective with a learning-rate scheduler and many epochs.

25 import numpy as np
26 import segmentation_models_pytorch as smp
27 import torch
28 from segmentation_models_pytorch.encoders import get_preprocessing_fn
29 from segmentation_models_pytorch.metrics import iou_score
30 from torch.utils.data import DataLoader
31 from torchvision.transforms.functional import pad, to_tensor
32
33 from pytorch_ood.dataset.img import StreetHazards
34 from pytorch_ood.detector import WeightedEBO
35 from pytorch_ood.loss import VOSRegLoss
36 from pytorch_ood.utils import OODMetrics, fix_random_seed
37
# --- experiment configuration -------------------------------------------------
device = "cuda:0"  # device used for training and evaluation
batch_size = 4  # images per batch
num_epochs = 1  # passes over the training data
lr = 1e-4  # base learning rate passed to Adam
num_classes = 13  # output channels of the segmentation model

# --- reproducibility ----------------------------------------------------------
# presumably seeds the global Python/NumPy/torch RNGs (pytorch-ood helper) — confirm
fix_random_seed(12345)
# dedicated generator handed to the DataLoaders
g = torch.Generator().manual_seed(0)

Set up preprocessing

preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")


def my_transform(img, target):
    """Turn a StreetHazards sample into tensors the segmentation model can consume.

    The image is converted to a tensor and reduced to its first three channels. It is
    then normalized with the ImageNet preprocessing of the ResNet-50 encoder. Finally it
    is padded by 8 pixels at the top and the bottom so that its spatial size becomes
    divisible by 32. The target is padded identically so it stays aligned with the image.

    :param img: input image, anything accepted by ``to_tensor``
    :param target: segmentation mask — assumed to already be a tensor; confirm
    :return: tuple ``(padded float image, padded target)``
    """
    rgb = to_tensor(img)[:3, :, :]  # discard the 4th channel
    # the smp preprocessing function expects channel-last data, so go CHW -> HWC -> CHW
    normalized = preprocess_input(rgb.permute(1, 2, 0)).permute(2, 0, 1)

    # padding spec is [left/right, top/bottom]: 8 px above and below
    padding = [0, 8]
    return pad(normalized, padding).float(), pad(target, padding)


def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Return the cosine-annealed value at ``step`` of a ``total_steps``-long schedule.

    The value equals ``lr_max`` at ``step == 0``. It decays along half a cosine period
    to ``lr_min`` at ``step == total_steps``.

    :param step: current step
    :param total_steps: length of the schedule (must be non-zero)
    :param lr_max: value at the start of the schedule
    :param lr_min: value at the end of the schedule
    :return: interpolated value
    """
    half_range = (lr_max - lr_min) * 0.5
    angle = step / total_steps * np.pi
    return lr_min + half_range * (1 + np.cos(angle))

Set up the datasets. For demonstration purposes, we train on OOD images (the test subset).

# Training and evaluation both use the *test* subset on purpose, because it is the
# split that contains anomalies. Two independent dataset instances are created.
_streethazards_kwargs = dict(root="data", subset="test", transform=my_transform, download=True)
dataset = StreetHazards(**_streethazards_kwargs)
dataset_test = StreetHazards(**_streethazards_kwargs)

Set up the model

# FPN segmentation network with an ImageNet-pretrained ResNet-50 encoder.
# It takes 3-channel input and produces one output channel per class.
model = smp.FPN(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,
    classes=num_classes,
)
model = model.to(device)

Create neural network functions (layers)

# Learnable auxiliary layers consumed by VOSRegLoss:
#  - ``phi`` maps a single input feature to 2 outputs (presumably the energy score ->
#    in-distribution vs. outlier logits; see the VOSRegLoss documentation),
#  - ``weights_energy`` maps the ``num_classes`` logits to one value. Its weights start
#    uniformly distributed in [0, 1), and the layer is also handed to WeightedEBO.
# The creation/initialization order is kept fixed because each step consumes random numbers.
phi = torch.nn.Linear(in_features=1, out_features=2).to(device)
weights_energy = torch.nn.Linear(in_features=num_classes, out_features=1).to(device)
torch.nn.init.uniform_(weights_energy.weight, a=0.0, b=1.0)

criterion = VOSRegLoss(phi, weights_energy, device=device)

Train the model for a few epochs

 96 optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
 97
 98
 99 loader = DataLoader(
100     dataset,
101     batch_size=batch_size,
102     shuffle=True,
103     num_workers=10,
104     worker_init_fn=fix_random_seed,
105     generator=g,
106 )
107
108 # setup scheduler for optimizer (recommended)
109 scheduler = torch.optim.lr_scheduler.LambdaLR(
110     optimizer,
111     lr_lambda=lambda step: cosine_annealing(
112         step,
113         num_epochs * len(loader),
114         1,  # since lr_lambda computes multiplicative factor
115         1e-6 / lr,
116     ),
117 )
118
119 ious = []
120 loss_ema = 0
121 ioe_ema = 0
122
123 for epoch in range(num_epochs):
124     for n, (x, y) in enumerate(loader):
125         optimizer.zero_grad()
126         y, x = y.to(device), x.to(device)
127
128         y_hat = model(x)
129         loss = criterion(y_hat, y)
130         loss.backward()
131         optimizer.step()
132         scheduler.step()
133
134         tp, fp, fn, tn = smp.metrics.get_stats(
135             y_hat.softmax(dim=1).max(dim=1).indices.long(),
136             y.long(),
137             mode="multiclass",
138             num_classes=13,
139         )
140         iou = iou_score(tp, fp, fn, tn)
141
142         loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
143         ioe_ema = 0.8 * ioe_ema + 0.2 * iou.mean().item()
144
145         if n % 10 == 0:
146             print(
147                 f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {ioe_ema:02.2f}"
148             )

Evaluate

print("Evaluating")
model.eval()
# A separate name keeps the training ``loader`` intact. The batch size reuses the
# configured value instead of a hard-coded one.
test_loader = DataLoader(dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g)
# WeightedEBO combines the model's outputs with the ``weights_energy`` layer that was
# also passed to VOSRegLoss.
detector = WeightedEBO(model, weights_energy)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for x, y in test_loader:
        y, x = y.to(device), x.to(device)
        o = detector(x)  # outlier scores; assumed to share the labels' spatial layout — confirm

        # undo the 8-pixel top/bottom padding added in my_transform (negative padding crops)
        o = pad(o, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(o, y)

print(metrics.compute())

Output:

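| Dataset                | Detector    | AUROC | AUTC  | AUPR-IN | AUPR-OUT | FPR95TPR |
|------------------------|-------------|-------|-------|---------|----------|----------|
| Streethazards+VOS-Loss | WeightedEBO | 93.56 | 36.51 | 99.94   | 15.45    | 17.98    |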

Gallery generated by Sphinx-Gallery