Note
Go to the end to download the full example code.
StreetHazards with VOS Loss
We train a Feature Pyramid Network (FPN) segmentation model
with a ResNet-50 backbone pre-trained on ImageNet
on the StreetHazards test set using
the supervised VOSRegLoss.
We then use the WeightedEBO OOD detector.
This setup is only meant to demonstrate how to train a supervised anomaly segmentation model with this loss function.
Warning
We train on the test set because it contains examples of anomalies. Consequently, the evaluation results are not meaningful.
Note
Training with a batch size of 4 requires slightly more than 12 GB of GPU memory. However, the models also tend to converge to reasonable performance with a smaller batch size. This loss is more effective with a learning-rate scheduler and many epochs.
24 import numpy as np
25 import segmentation_models_pytorch as smp
26 import torch
27 from segmentation_models_pytorch.encoders import get_preprocessing_fn
28 from segmentation_models_pytorch.metrics import iou_score
29 from torch.utils.data import DataLoader
30 from torchvision.transforms.functional import pad, to_tensor
31
32 from pytorch_ood.dataset.img import StreetHazards
33 from pytorch_ood.detector import WeightedEBO
34 from pytorch_ood.loss import VOSRegLoss
35 from pytorch_ood.utils import OODMetrics, fix_random_seed
36
# --- experiment configuration ----------------------------------------------
device = "cuda:0"
batch_size = 4
num_epochs = 1
lr = 1e-4
num_classes = 13  # number of output classes of the segmentation model

# Seed all RNGs globally, then build a separately seeded generator that is
# handed to the DataLoaders for reproducible shuffling.
fix_random_seed(12345)
g = torch.Generator().manual_seed(0)
Set up preprocessing
50 preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")
51
52
53 def my_transform(img, target):
54 img = to_tensor(img)[:3, :, :] # drop 4th channel
55 img = torch.moveaxis(img, 0, -1)
56 img = preprocess_input(img)
57 img = torch.moveaxis(img, -1, 0)
58
59 # size must be divisible by 32, so we pad the image.
60 img = pad(img, [0, 8]).float()
61 target = pad(target, [0, 8])
62 return img, target
63
64
def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Return the cosine-annealed value for ``step`` out of ``total_steps``.

    The value starts at ``lr_max`` for ``step == 0`` and follows half a
    cosine period down to ``lr_min`` at ``step == total_steps``.
    """
    progress = step / total_steps
    half_span = (lr_max - lr_min) * 0.5
    return lr_min + half_span * (1 + np.cos(progress * np.pi))
Set up the datasets. For demonstration purposes, we train on the OOD (test) images.
# Training and evaluation both use the StreetHazards *test* split (see warning).
_streethazards_kwargs = dict(root="data", subset="test", transform=my_transform, download=True)
dataset = StreetHazards(**_streethazards_kwargs)
dataset_test = StreetHazards(**_streethazards_kwargs)
Set up the model
# FPN segmentation network: ImageNet-pretrained ResNet-50 encoder, RGB input,
# ``num_classes`` output classes.
model = smp.FPN(encoder_name="resnet50", encoder_weights="imagenet", in_channels=3, classes=num_classes)
model = model.to(device)
Create neural network functions (layers)
# Layers consumed by the VOS regularization loss:
#   phi            -- linear map from 1 input feature to 2 outputs
#   weights_energy -- linear map from a num_classes-dim input to 1 output,
#                     its weight re-initialized uniformly
phi = torch.nn.Linear(1, 2)
phi = phi.to(device)
weights_energy = torch.nn.Linear(num_classes, 1)
weights_energy = weights_energy.to(device)
torch.nn.init.uniform_(weights_energy.weight)

criterion = VOSRegLoss(phi, weights_energy, device=device)
Train model for some epochs
# Optimize the segmentation network *and* the two layers handed to
# VOSRegLoss. ``phi`` and ``weights_energy`` are used inside ``criterion``, and
# ``weights_energy`` is also used by the WeightedEBO detector at evaluation
# time. If only ``model.parameters()`` were passed, both layers would keep
# their random initialization forever.
optimizer = torch.optim.Adam(
    params=list(model.parameters()) + list(phi.parameters()) + list(weights_energy.parameters()),
    lr=lr,
)


loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=10,
    worker_init_fn=fix_random_seed,
    generator=g,
)

# setup scheduler for optimizer (recommended): cosine decay of the learning
# rate from ``lr`` down to 1e-6 over all training steps.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: cosine_annealing(
        step,
        num_epochs * len(loader),
        1,  # since lr_lambda computes multiplicative factor
        1e-6 / lr,
    ),
)

# exponential moving averages of loss and mean IoU, used only for logging
loss_ema = 0
iou_ema = 0

model.train()
for epoch in range(num_epochs):
    for n, (x, y) in enumerate(loader):
        optimizer.zero_grad()
        y, x = y.to(device), x.to(device)

        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # per-pixel predicted class (argmax over the class dimension);
        # detached because the metric does not need gradients
        y_pred = y_hat.detach().argmax(dim=1)
        tp, fp, fn, tn = smp.metrics.get_stats(
            y_pred.long(),
            y.long(),
            mode="multiclass",
            num_classes=num_classes,
        )
        iou = iou_score(tp, fp, fn, tn)

        loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
        iou_ema = 0.8 * iou_ema + 0.2 * iou.mean().item()

        if n % 10 == 0:
            print(
                f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {iou_ema:02.2f}"
            )
Evaluate
print("Evaluating")
model.eval()
loader = DataLoader(dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g)
detector = WeightedEBO(model, weights_energy)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for n, (x, y) in enumerate(loader):
        y, x = y.to(device), x.to(device)
        o = detector(x)

        # Undo padding. my_transform used ``pad(..., [0, 8])``, i.e. 0 px
        # left/right and 8 px top/bottom. With a 2-element sequence, pad's
        # first value applies to left/right and the second to top/bottom, so
        # [0, -8] crops exactly the added rows and keeps every column.
        # (The previous [-8, -8] also cut 8 real columns from each side.)
        o = pad(o, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(o, y)

print(metrics.compute())
Output: {'AUROC': 0.9346237778663635, 'AUPR-IN': 0.15255042910575867, 'AUPR-OUT': 0.9993401169776917, 'FPR95TPR': 0.18086743354797363}