import os

import onnxruntime
import torch
import torch.utils.data
import torchvision
from torch import nn
from torchvision.transforms.functional import InterpolationMode

import utils


def evaluate(
    criterion,
    data_loader,
    device,
    model=None,
    model_onnx_path=None,
    print_freq=100,
    log_suffix="",
):
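    """Evaluate a classifier on `data_loader` and report top-1/top-5 accuracy.

    At least one of `model` (an eval-mode torch.nn.Module) or `model_onnx_path`
    (path to a .onnx file) must be provided; the ONNX model takes precedence
    when both are given.
    """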
    if model is None and model_onnx_path is None:
        raise ValueError("Either `model` or `model_onnx_path` must be provided")

    if model_onnx_path:
        session = onnxruntime.InferenceSession(
            model_onnx_path, providers=["CPUExecutionProvider"]
        )
        input_name = session.get_inputs()[0].name

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = f"Test: {log_suffix}"

    num_processed_samples = 0
    with torch.inference_mode():
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
            target = target.to(device, non_blocking=True)
            image = image.to(device)

            if model_onnx_path:
                # ONNX Runtime consumes numpy arrays, so move the batch to CPU first
                input_data = image.cpu().numpy()

                # None requests all model outputs; take the first one
                output_data = session.run(None, {input_name: input_data})[0]

                # back to a torch tensor on the evaluation device
                output = torch.from_numpy(output_data).to(device)
            elif model:
                output = model(image)

            loss = criterion(output, target)

            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
            # FIXME need to take into account that the datasets
            # could have been padded in distributed setup
            batch_size = image.shape[0]
            metric_logger.update(loss=loss.item())
            metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
            metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
            num_processed_samples += batch_size
    # gather the stats from all processes

    metric_logger.synchronize_between_processes()

    print(
        f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}"
    )
    return metric_logger.acc1.global_avg


def load_data(valdir):
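    """Build an ImageFolder validation dataset with the standard ImageNet eval transforms."""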
    # Data loading code
    print("Loading data")
    interpolation = InterpolationMode.BILINEAR

    preprocessing = torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(256, interpolation=interpolation),
            torchvision.transforms.CenterCrop(224),
            torchvision.transforms.PILToTensor(),
            torchvision.transforms.ConvertImageDtype(torch.float),
            torchvision.transforms.Normalize(
                mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
            ),
        ]
    )

    dataset_test = torchvision.datasets.ImageFolder(
        valdir,
        preprocessing,
    )

    print("Creating data loaders")
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    return dataset_test, test_sampler


def main(args):
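    """Load the validation data and the requested model(s), then run evaluation."""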
    print(args)

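    # make cuDNN deterministic so repeated evaluations produce identical numbers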
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    val_dir = os.path.join(args.data_path, "val")
    dataset_test, test_sampler = load_data(val_dir)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
    )

    print("Creating model")

    criterion = nn.CrossEntropyLoss()

    model = None
    if args.model_ckpt:
        checkpoint = torch.load(args.model_ckpt, map_location="cpu")
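        # the checkpoint is assumed to store the full serialized model under "model_ckpt"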
        model = checkpoint["model_ckpt"]
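        # prefer EMA weights when present, stripping the "module." prefix added by the EMA wrapper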
        if "model_ema" in checkpoint:
            state_dict = {}
            for key, value in checkpoint["model_ema"].items():
                if not "module." in key:
                    continue
                state_dict[key.replace("module.", "")] = value
            model.load_state_dict(state_dict)
        model = model.to(device)
        model.eval()

    accuracy = evaluate(
        model=model,
        model_onnx_path=args.model_onnx,
        criterion=criterion,
        data_loader=data_loader_test,
        device=device,
        print_freq=args.print_freq,
    )
    print(f"Model accuracy is: {accuracy}")


def get_args_parser(add_help=True):
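    """Build the command-line argument parser for this evaluation script."""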
    import argparse

    parser = argparse.ArgumentParser(
        description="PyTorch Classification Training", add_help=add_help
    )

    parser.add_argument(
        "--data-path", default="datasets/imagenet", type=str, help="dataset path"
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        default=32,
        type=int,
        help="images per gpu, the total batch size is $NGPU x batch_size",
    )
    parser.add_argument(
        "-j",
        "--workers",
        default=16,
        type=int,
        metavar="N",
        help="number of data loading workers (default: 16)",
    )
    parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
    parser.add_argument(
        "--model-onnx", default="", type=str, help="path of .onnx checkpoint"
    )
    parser.add_argument(
        "--model-ckpt", default="", type=str, help="path of .pth checkpoint"
    )

    return parser


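# Example invocations (script name and paths are illustrative):
#   python eval.py --data-path datasets/imagenet --model-ckpt checkpoints/model.pth
#   python eval.py --data-path datasets/imagenet --model-onnx checkpoints/model.onnx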
if __name__ == "__main__":
    args = get_args_parser().parse_args()
    main(args)