StreetHazards with Entropic Loss

We train a Feature Pyramid Network (FPN) segmentation model with a ResNet-50 backbone, pre-trained on ImageNet, on the StreetHazards test set using the supervised EntropicOpenSetLoss.

We then use the Entropy OOD detector to compute pixel-wise outlier scores.

This setup merely demonstrates how to train a supervised anomaly segmentation model with this loss function.

Warning

We train on the test set because it contains examples of anomalies. Consequently, the results are not meaningful.

Note

Training with a batch size of 4 requires slightly more than 12 GB of GPU memory. However, the model also tends to converge to reasonable performance with a smaller batch size.

import segmentation_models_pytorch as smp
import torch
from segmentation_models_pytorch.encoders import get_preprocessing_fn
from segmentation_models_pytorch.metrics import iou_score
from torch.utils.data import DataLoader
from torchvision.transforms.functional import pad, to_tensor

from pytorch_ood.dataset.img import StreetHazards
from pytorch_ood.detector import Entropy
from pytorch_ood.loss import EntropicOpenSetLoss
from pytorch_ood.utils import OODMetrics, fix_random_seed

# Use the first CUDA device when available; otherwise fall back to the CPU so
# the example still runs (slowly) on machines without a GPU instead of crashing.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
batch_size = 4
num_epochs = 1

# Seed for reproducibility (pytorch_ood helper). ``g`` is the generator handed
# to the DataLoaders below, which makes the shuffling order deterministic.
fix_random_seed(12345)
g = torch.Generator()
g.manual_seed(0)

Set up preprocessing

# ImageNet normalization matching the pre-trained ResNet-50 encoder.
preprocess_input = get_preprocessing_fn("resnet50", pretrained="imagenet")


def my_transform(img, target):
    """Prepare one (image, target) pair for the FPN model.

    Only the first three channels of the image are kept. The encoder's
    preprocessing function is applied in channels-last layout. Then both image
    and target are padded by 8 pixels at the top and at the bottom
    (``padding=[left/right, top/bottom]``).

    Returns:
        Tuple ``(img, target)``; ``img`` is cast to float.
    """
    rgb = to_tensor(img)[:3, :, :]  # keep 3 channels, drop the 4th one

    # preprocess_input works on channels-last data: (C, H, W) -> (H, W, C),
    # normalize, then back to (C, H, W).
    channels_last = rgb.permute(1, 2, 0)
    normalized = preprocess_input(channels_last).permute(2, 0, 1)

    # The FPN needs spatial sizes divisible by 32, hence the extra rows.
    # NOTE(review): assumes the input height is 720 px (720 + 2 * 8 = 736 = 23 * 32) — confirm.
    # NOTE(review): pad() fills with 0, so padded target rows get label 0 — confirm intended.
    padded_img = pad(normalized, [0, 8]).float()
    padded_target = pad(target, [0, 8])
    return padded_img, padded_target

Set up the datasets. For demonstration purposes, we train on the OOD (test) images.

# Training data: deliberately the *test* split, because it contains anomalies.
dataset = StreetHazards(
    root="data",
    subset="test",
    transform=my_transform,
    download=True,
)
# Evaluation data: a separate instance over the same split.
dataset_test = StreetHazards(
    root="data",
    subset="test",
    transform=my_transform,
    download=True,
)

Set up the model

# FPN segmentation network: ImageNet-pretrained ResNet-50 encoder,
# 3 input channels (RGB), 13 output classes.
model = smp.FPN(
    classes=13,
    in_channels=3,
    encoder_weights="imagenet",
    encoder_name="resnet50",
)
model = model.to(device)

Train the model for ``num_epochs`` epochs

 79 criterion = EntropicOpenSetLoss()
 80 optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
 81 loader = DataLoader(
 82     dataset,
 83     batch_size=batch_size,
 84     shuffle=True,
 85     num_workers=10,
 86     worker_init_fn=fix_random_seed,
 87     generator=g,
 88 )
 89
 90 ious = []
 91 loss_ema = 0
 92 ioe_ema = 0
 93
 94 for epoch in range(num_epochs):
 95     for n, (x, y) in enumerate(loader):
 96         optimizer.zero_grad()
 97         y, x = y.to(device), x.to(device)
 98
 99         y_hat = model(x)
100         loss = criterion(y_hat, y)
101         loss.backward()
102         optimizer.step()
103
104         tp, fp, fn, tn = smp.metrics.get_stats(
105             y_hat.softmax(dim=1).max(dim=1).indices.long(),
106             y.long(),
107             mode="multiclass",
108             num_classes=13,
109         )
110         iou = iou_score(tp, fp, fn, tn)
111
112         loss_ema = 0.8 * loss_ema + 0.2 * loss.item()
113         ioe_ema = 0.8 * ioe_ema + 0.2 * iou.mean().item()
114
115         if n % 10 == 0:
116             print(
117                 f"Epoch {epoch:03d} [{n:05d}/{len(loader):05d}] \t Loss: {loss_ema:02.2f} \t IoU: {ioe_ema:02.2f}"
118             )

Evaluate

print("Evaluating")
model.eval()
loader = DataLoader(dataset_test, batch_size=batch_size, worker_init_fn=fix_random_seed, generator=g)
detector = Entropy(model)
metrics = OODMetrics(mode="segmentation")

# my_transform pads with pad(..., [0, 8]): 8 rows at the top and 8 at the bottom,
# with no padding on the left/right.
pad_rows = 8

with torch.no_grad():
    for x, y in loader:
        y, x = y.to(device), x.to(device)
        o = detector(x)

        # Undo the padding by cropping only the padded rows along the height
        # axis (dim -2, the axis torchvision's pad treats as height). Cropping
        # the width too would drop real image columns from the evaluation.
        o = o[..., pad_rows:-pad_rows, :]
        y = y[..., pad_rows:-pad_rows, :]

        metrics.update(o, y)

print(metrics.compute())
Output:

{'AUROC': 0.9705050587654114, 'AUPR-IN': 0.3917403519153595, 'AUPR-OUT': 0.9997314214706421, 'FPR95TPR': 0.10926716774702072}

Gallery generated by Sphinx-Gallery