"""
Script to run the prediction on FGVC aircraft dataset with a resnet18 classifier
It requires two other files, downloaded in the same folder as this script:
    * the list : images_family_infer.csv
    * the model : resnet18_level0.ckpt
Use of Python >=3.11 is highly recommended
With uv (https://docs.astral.sh/uv/getting-started/installation/#installation-methods)
```
    uv python install 3.13
    uv python pin 3.13
    uv add jsonargparse lightning pandas torchvision scipy
    uv run infer.py
```
With pip 
```
    python -m venv venv
    source venv/bin/activate
    python -m pip install jsonargparse lightning pandas torchvision scipy
    python infer.py
```
"""


import torch
import pandas as pd
from tqdm import tqdm
from torchvision import datasets
import torchvision.models as torchvision_models
from torch.utils.data import Dataset
import torchvision.transforms as transforms
import pandas as pd
from jsonargparse import ActionConfigFile, ArgumentParser
from PIL import Image


import os


# Per-channel statistics handed to transforms.Normalize below.
# NOTE(review): these match the values commonly quoted for ImageNet-pretrained
# torchvision models -- confirm they are the ones used to train the checkpoint.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image before it is fed to the model:
# resize, crop the central 224x224 region, convert to a tensor, normalize.
basic_data_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
    ]
)


def parse_opt():
    """Build the command-line parser and return the parsed options.

    Registered options: ``--cfg`` (a config file, handled by
    ``ActionConfigFile``), ``--batch_size``, ``--device``, ``--data_dir``,
    ``--csv_file``, ``--model_filepath`` and ``--res_filepath``.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    cli = ArgumentParser()
    cli.add_argument("--cfg", action=ActionConfigFile)
    cli.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="Number of data to process simultanously",
    )
    cli.add_argument(
        "--device", type=str, default="cpu", help="device used for inference"
    )

    # The remaining options are all string paths registered the same way:
    # (flag, default value, help text).
    # NOTE(review): each one sets required=True alongside a default -- confirm
    # how jsonargparse treats that combination.
    path_options = (
        (
            "--data_dir",
            "/tmp/data/",
            "Path to the data directory, where to download the dataset",
        ),
        (
            "--csv_file",
            "images_family_infer.csv",
            "Path to the csv file linking each image to a class outputed by the model",
        ),
        (
            "--model_filepath",
            "resnet18_level0.ckpt",
            "Path to the model checkpoint",
        ),
        (
            "--res_filepath",
            "res0.csv",
            "Path to the csv file output, with one pred per data",
        ),
    )
    for flag, default_value, description in path_options:
        cli.add_argument(
            flag,
            type=str,
            default=default_value,
            required=True,
            help=description,
        )

    return cli.parse_args()


def make_model(
    model_archi, model_init_weights=None, state_dict=None, nb_classes=50, V1=False
):
    """Instantiate a torchvision model and replace its ``fc`` head.

    Args:
        model_archi: Name of the model constructor in ``torchvision.models``;
            it is lower-cased before the lookup.
        model_init_weights: Name of an attribute of ``torchvision.models``
            from which the initial weights are taken, or ``None`` to pass
            ``weights=None`` to the constructor.
        state_dict: Optional state dict loaded into the model after the
            ``fc`` layer has been replaced.
        nb_classes: Output size of the new ``fc`` layer.
        V1: When initial weights are requested, pick ``IMAGENET1K_V1``
            instead of ``DEFAULT``.

    Returns:
        The constructed model.
    """
    initial_weights = None
    if model_init_weights is not None:
        weights_source = getattr(torchvision_models, model_init_weights)
        initial_weights = (
            weights_source.IMAGENET1K_V1 if V1 else weights_source.DEFAULT
        )

    constructor = getattr(torchvision_models, model_archi.lower())
    net = constructor(weights=initial_weights)

    # Swap the final layer for one sized to the requested number of classes.
    net.fc = torch.nn.Linear(net.fc.in_features, nb_classes)

    if state_dict is not None:
        net.load_state_dict(state_dict)
    return net


class AircraftDataset(Dataset):
    """Dataset yielding ``(image id, image, label)`` triples from a dataframe.

    Image ids come from the ``"id"`` column and labels from the ``"class"``
    column; each image is read from ``<img_fold>/<id>.jpg``.
    """

    def __init__(self, df, img_fold, transform=None):
        """Store the id/label columns, the image folder and the transform."""
        self.image_ids = df["id"].values
        self.label = df["class"].values
        self.root = img_fold
        self.transform = transform

    def __len__(self):
        """Return the number of image ids."""
        return len(self.image_ids)

    def __getitem__(self, index):
        """Load one image as RGB and return ``(id as str, image, label)``."""
        image_id = str(self.image_ids[index])
        picture = Image.open(os.path.join(self.root, image_id + ".jpg")).convert(
            "RGB"
        )
        if self.transform:
            picture = self.transform(picture)
        return image_id, picture, self.label[index]


def _extract_model_state_dict(ckpt):
    """Return the ``state_dict`` entries under ``model.``, with that prefix removed."""
    weights = {}
    for key, value in ckpt["state_dict"].items():
        prefix, _, remainder = key.partition(".")
        if prefix == "model":
            weights[remainder] = value
    return weights


def make_inference(
    model_filepath, device, batch_size, img_folder, csv_images, res_filepath
):
    """Run the resnet18 classifier on the listed images and write a CSV.

    Args:
        model_filepath: Path to the checkpoint; it must contain a
            ``"state_dict"`` entry whose model weights are prefixed with
            ``model.``.
        device: Device the checkpoint is mapped to and the model runs on.
        batch_size: Batch size of the inference data loader.
        img_folder: Folder containing the ``<id>.jpg`` images.
        csv_images: CSV file with at least the ``id`` and ``class`` columns.
        res_filepath: Output CSV path; it is overwritten and receives one
            ``image_id,truth,pred,maxlogit`` line per image.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry or no
            ``model.fc.bias`` weight.
    """
    # NOTE(security): torch.load relies on pickle unless weights-only loading
    # is in effect -- only load checkpoints coming from a trusted source.
    ckpt = torch.load(model_filepath, map_location=device)
    state_dict = _extract_model_state_dict(ckpt)
    if "fc.bias" not in state_dict:
        raise KeyError(
            f"No 'model.fc.bias' entry found in the state_dict of {model_filepath}"
        )

    # The number of classes is read from the size of the final layer's bias.
    model_nb_classes = state_dict["fc.bias"].shape[0]
    model = make_model(
        model_archi="resnet18",
        model_init_weights=None,
        nb_classes=model_nb_classes,
        state_dict=state_dict,
    )
    model.eval()
    model.to(device)
    print(model)

    # Everything is read as str so ids keep their exact text (e.g. leading
    # zeros) and empty cells are not turned into NaN.
    df_images = pd.read_csv(csv_images, dtype=str, keep_default_na=False)
    infer_dataset = AircraftDataset(
        df=df_images,
        img_fold=img_folder,
        transform=basic_data_transforms,
    )
    infer_loader = torch.utils.data.DataLoader(
        infer_dataset, batch_size=batch_size, shuffle=False, num_workers=2
    )

    # The file is opened once; gradients are disabled because nothing is
    # back-propagated during inference.
    with open(res_filepath, "w") as f, torch.no_grad():
        f.write("image_id,truth,pred,maxlogit\n")
        for images_ids, images, labels in tqdm(infer_loader, total=len(infer_loader)):
            logits = model(images.to(device)).to("cpu")
            maxlogits, preds = torch.max(logits, 1)
            f.write(
                "".join(
                    f"{image_id},{label},{pred},{maxlog}\n"
                    for image_id, label, maxlog, pred in zip(
                        images_ids, labels, maxlogits, preds
                    )
                )
            )
            # Flush after each batch so finished results reach the file even
            # if a later batch fails.
            f.flush()


def main():
    """Parse the options, fetch FGVC-Aircraft and run the inference."""
    args = parse_opt()

    download_dir = os.path.join(args.data_dir, "download")
    print(f"Download the data in {download_dir}")
    # Instantiated with download=True; the returned dataset object itself is
    # not used afterwards, the images are read directly from disk.
    datasets.FGVCAircraft(root=download_dir, split="trainval", download=True)

    jpg_folder = os.path.join(download_dir, "fgvc-aircraft-2013b/data/images/")
    print("Start running the inference")
    make_inference(
        model_filepath=args.model_filepath,
        device=args.device,
        batch_size=args.batch_size,
        img_folder=jpg_folder,
        csv_images=args.csv_file,
        res_filepath=args.res_filepath,
    )
    print(f"Inference done, results written in {args.res_filepath}")


if __name__ == "__main__":
    main()
