"""
OpenOOD v1.5 - ImageNet
========================

Reproduces the OpenOOD v1.5 benchmark for OOD detection on ImageNet, using a pre-trained ResNet 50.

"""

import pandas as pd  # additional dependency, used here for convenience
import torch
from torchvision.models import resnet50
from torchvision.models.resnet import ResNet50_Weights

from pytorch_ood.benchmark import ImageNet_OpenOOD
from pytorch_ood.detector import MaxSoftmax
from pytorch_ood.utils import fix_random_seed

fix_random_seed(123)

# Prefer the first GPU, but fall back to the CPU so the example still runs
# (more slowly) on machines without CUDA instead of failing on the first .to(device).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# Forwarded to ``benchmark.evaluate`` (presumably used to build its data loaders).
loader_kwargs = {"batch_size": 16, "num_workers": 12}

# %%
# Load a ResNet-50 with the torchvision ``IMAGENET1K_V1`` weights, together with the
# preprocessing transform that is bundled with exactly those weights.
# ``weights`` is passed by keyword: torchvision declares it keyword-only and only
# tolerates the positional form through a deprecation shim.
model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1).eval().to(device)
trans = ResNet50_Weights.IMAGENET1K_V1.transforms()

print(trans)

# %%
# Each entry maps the name shown in the results table to a detector wrapping the model.
# To benchmark further methods, simply register them in this dictionary.
detectors = {"MSP": MaxSoftmax(model)}

# %%
# ``image_net_root`` must contain at least the ImageNet validation archive, the devkit
# archive (for torchvision typically ``ILSVRC2012_img_val.tar`` and
# ``ILSVRC2012_devkit_t12.tar.gz``), and the ``meta.bin`` file that torchvision's
# ``ImageNet`` dataset generates from them.
benchmark = ImageNet_OpenOOD(
    root="data",
    image_net_root="data/imagenet-2012/",
    transform=trans,
)
results = []


# Gradients are never needed for scoring, so disable autograd for the whole sweep.
with torch.no_grad():
    for detector_name, detector in detectors.items():
        print(f"> Evaluating {detector_name}")
        per_dataset = benchmark.evaluate(detector, loader_kwargs=loader_kwargs, device=device)
        # Tag every per-dataset record with the detector that produced it, so the
        # rows can later be indexed by (Dataset, Detector).
        for record in per_dataset:
            record["Detector"] = detector_name
        results.extend(per_dataset)

# Build the results table indexed by (Dataset, Detector) and scale every metric by
# 100 so the printed CSV reads as percentages with two decimals.
df = pd.DataFrame(results)
percent = df.set_index(["Dataset", "Detector"]) * 100
print(percent.to_csv(float_format="%.2f"))

# %%
# This should produce a table with one row per (dataset, detector) pair, with all
# metrics multiplied by 100, covering the following OOD datasets:
#
# * SSBHard (near-OOD)
# * NINCO (near-OOD)
# * iNaturalist (far-OOD)
# * Textures (far-OOD)
# * OpenImagesO (far-OOD)
#
