StreetHazards with Entropic Loss

We train a Feature Pyramid Network (FPN) segmentation model with a ResNet-50 backbone pre-trained on ImageNet. The model is trained on the StreetHazards test set using the supervised EntropicOpenSetLoss.

We then apply the Entropy OOD detector to the trained model.

This setup is only meant to demonstrate how to train a supervised anomaly segmentation model with this loss function.

Warning

We train on the test set because it contains examples of anomalies. Since the model is also evaluated on this same set, the results are not meaningful.

Note

Training with a batch size of 4 requires slightly more than 12 GB of GPU memory. However, the model also tends to converge to reasonable performance with a smaller batch size.

24 import segmentation_models_pytorch as smp
25 import torch
26 from segmentation_models_pytorch.encoders import get_preprocessing_fn
27 from segmentation_models_pytorch.metrics import iou_score
28 from torch.utils.data import DataLoader
29 from torchvision.transforms.functional import pad, to_tensor
30
31 from pytorch_ood.dataset.img import StreetHazards
32 from pytorch_ood.detector import Entropy
33 from pytorch_ood.loss import EntropicOpenSetLoss
34 from pytorch_ood.utils import OODMetrics, fix_random_seed
35
# Run on the first GPU when CUDA is available; otherwise fall back to the CPU so
# the example still runs (slowly) on machines without a GPU instead of crashing.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4
num_epochs = 1

# Fix random seeds for reproducibility via pytorch-ood's helper.
fix_random_seed(12345)

# Dedicated generator handed to the DataLoaders so the shuffling order is reproducible.
g = torch.Generator()
g.manual_seed(0)

Set up preprocessing

# ImageNet preprocessing function provided by smp for the ResNet-50 encoder.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")


def my_transform(img, target):
    """Convert a raw StreetHazards sample into a model-ready ``(img, target)`` pair.

    The image is turned into a tensor and reduced to its first three channels.
    It is then passed through ``preprocess_input`` and cast to float32. Image
    and target are both padded by 8 pixels at the top and bottom (0 on the
    left/right), because the spatial size must be divisible by 32.
    """
    rgb = to_tensor(img)[:3, :, :]  # keep channels 0..2, drop the 4th channel

    # preprocess_input is applied channels-last (presumably so per-channel
    # statistics broadcast over the last axis): CHW -> HWC -> CHW.
    channels_last = preprocess_input(rgb.permute(1, 2, 0))
    channels_first = channels_last.permute(2, 0, 1)

    # A 2-element padding means [left/right, top/bottom].
    padding = [0, 8]
    return pad(channels_first, padding).float(), pad(target, padding)

Set up the datasets. For demonstration purposes, we train on the OOD (test) images.

# Training and evaluation both use the StreetHazards "test" subset on purpose,
# because it contains anomalies (see the warning above). Two separate dataset
# instances are created, the training one first.
dataset, dataset_test = (
    StreetHazards(root="data", subset="test", transform=my_transform, download=True)
    for _ in range(2)
)

Set up the model

# Feature Pyramid Network with an ImageNet-pretrained ResNet-50 encoder:
# 3 input channels, 13 output classes.
model = smp.FPN(
    classes=13,
    in_channels=3,
    encoder_weights="imagenet",
    encoder_name="resnet50",
)
model = model.to(device)  # nn.Module.to moves parameters in place and returns the module

Train the model for ``num_epochs`` epochs

 80 criterion = EntropicOpenSetLoss()
 81 optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
 82 loader = DataLoader(
 83     dataset,
 84     batch_size=batch_size,
 85     shuffle=True,
 86     num_workers=10,
 87     worker_init_fn=fix_random_seed,
 88     generator=g,
 89 )
 90
 91 ious = []
 92 loss_ema = 0
 93 ioe_ema = 0
 94
 95 for epoch in range(num_epochs):
 96     for n, (x, y) in enumerate(loader):
 97         optimizer.zero_grad()
 98         y, x = y.to(device), x.to(device)
 99
100         y_hat = model(x)
101         loss = criterion(y_hat, y)
102         loss.backward()
103         optimizer.step()
104
105         tp, fp, fn, tn = smp.metrics.get_stats(
106             y_hat.softmax(dim=1).max(dim=1).indices.long(),
107             y.long(),
108             mode="multiclass",
109             num_classes=13,
110         )
111         iou = iou_score(tp, fp, fn, tn)
112
113         loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
114         ioe_ema = 0.8 * ioe_ema + 0.2 * iou.mean().item()
115
116         if n % 10 == 0:
117             print(
118                 f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {ioe_ema:02.2f}"
119             )

Evaluate

print("Evaluating")
model.eval()
# Reuse the configured batch size instead of a separately hard-coded value.
loader = DataLoader(dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g)
detector = Entropy(model)
metrics = OODMetrics(mode="segmentation")

with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        o = detector(x)  # outlier scores; still padded like the input image

        # Undo the 8-pixel top/bottom padding added in my_transform.
        # Negative padding in torchvision's pad crops the tensor.
        o = pad(o, [0, -8])
        y = pad(y, [0, -8])

        metrics.update(o, y)

print(metrics.compute())

Output:

| Dataset | Detector | AUROC | AUTC | AUPR-IN | AUPR-OUT | FPR95TPR |
|---|---|---|---|---|---|---|
| Streethazards+Entropic-Loss | Entropy | 97.23 | 18.69 | 99.98 | 39.51 | 10.43 |

Gallery generated by Sphinx-Gallery