from __future__ import annotations

import functools
from typing import Any, List

import huggingface_hub
import numpy as np
import onnxruntime as rt
import pandas as pd
import PIL.Image
import piexif
import piexif.helper

import dbimutils
import model

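# Optional Hugging Face access token; the tagger repos below are public, so it can be left empty.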
HF_TOKEN = ""
SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
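
# Cache of loaded ONNX sessions, keyed by model name; filled lazily by change_model()
loaded_models: dict[str, rt.InferenceSession | None] = {
    "SwinV2": None,
    "ConvNext": None,
    "ConvNextV2": None,
    "ViT": None,
}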


def change_model(model_name):
    """Load the requested tagger model into the cache and return its session."""
    if model_name == "SwinV2":
        session = load_model(SWIN_MODEL_REPO, MODEL_FILENAME)
    elif model_name == "ConvNext":
        session = load_model(CONV_MODEL_REPO, MODEL_FILENAME)
    elif model_name == "ConvNextV2":
        session = load_model(CONV2_MODEL_REPO, MODEL_FILENAME)
    elif model_name == "ViT":
        session = load_model(VIT_MODEL_REPO, MODEL_FILENAME)
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    loaded_models[model_name] = session
    return session


def load_model(model_repo: str, model_filename: str) -> rt.InferenceSession:
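    """Download an ONNX tagger model from the Hugging Face Hub and open an onnxruntime session."""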
    path = huggingface_hub.hf_hub_download(
        model_repo, model_filename, token=HF_TOKEN or None
    )
    return rt.InferenceSession(path)


def load_labels() -> tuple[list[Any], list[Any], list[Any], list[Any]]:
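    """Load selected_tags.csv and split it into tag names plus category index lists.

    Categories in the CSV: 9 = rating, 0 = general, 4 = character.
    """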
    path = huggingface_hub.hf_hub_download(
        CONV2_MODEL_REPO, LABEL_FILENAME, token=HF_TOKEN or None
    )
    df = pd.read_csv(path)

    tag_names = df["name"].tolist()
    rating_indexes = list(np.where(df["category"] == 9)[0])
    general_indexes = list(np.where(df["category"] == 0)[0])
    character_indexes = list(np.where(df["category"] == 4)[0])
    return tag_names, rating_indexes, general_indexes, character_indexes


def predict(
        image: PIL.Image.Image | str,
        model_name: str,
        general_threshold: float,
        character_threshold: float,
        tag_names: list[str],
        rating_indexes: list[np.int64],
        general_indexes: list[np.int64],
        character_indexes: list[np.int64],
):
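    """Run the selected tagger model on an image (PIL image or URL).

    Returns a dict with the escaped ("a") and raw ("c") comma-separated general
    tags, the rating scores, and the character/general tags whose confidence
    passed the thresholds.
    """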
    global loaded_models

    if isinstance(image, str):
        rawimage = dbimutils.read_img_from_url(image)
    elif isinstance(image, PIL.Image.Image):
        rawimage = image
    else:
        raise TypeError(f"Invalid image type: {type(image)!r}")

    image = rawimage

    session = loaded_models.get(model_name)
    if session is None:
        session = change_model(model_name)

    # The model input is NHWC; height gives the expected square input size
    _, height, width, _ = session.get_inputs()[0].shape

    # Alpha to white
    image = image.convert("RGBA")
    new_image = PIL.Image.new("RGBA", image.size, "WHITE")
    new_image.paste(image, mask=image)
    image = new_image.convert("RGB")
    image = np.asarray(image)

    # PIL RGB to OpenCV BGR
    image = image[:, :, ::-1]

    image = dbimutils.make_square(image, height)
    image = dbimutils.smart_resize(image, height)
    image = image.astype(np.float32)
    image = np.expand_dims(image, 0)

    input_name = session.get_inputs()[0].name
    label_name = session.get_outputs()[0].name
    probs = session.run([label_name], {input_name: image})[0]

    labels = list(zip(tag_names, probs[0].astype(float)))

    # Ratings: collect the rating-category labels with their scores
    ratings_names = [labels[i] for i in rating_indexes]
    rating = dict(ratings_names)

    # Then we have general tags: pick any where prediction confidence > threshold
    general_names = [labels[i] for i in general_indexes]
    general_res = [x for x in general_names if x[1] > general_threshold]
    general_res = dict(general_res)

    # Everything else is characters: pick any where prediction confidence > threshold
    character_names = [labels[i] for i in character_indexes]
    character_res = [x for x in character_names if x[1] > character_threshold]
    character_res = dict(character_res)

    # Sort general tags by confidence; "a" is the prompt-style escaped string, "c" the raw one
    b = dict(sorted(general_res.items(), key=lambda item: item[1], reverse=True))
    a = (
        ", ".join(list(b.keys()))
        .replace("_", " ")
        .replace("(", "\\(")
        .replace(")", "\\)")
    )
    c = ", ".join(list(b.keys()))

    # Copy the image metadata so popping keys below does not mutate the PIL image itself
    items = dict(rawimage.info)
    geninfo = ""

    if "exif" in rawimage.info:
        exif = piexif.load(rawimage.info["exif"])
        exif_comment = (exif or {}).get("Exif", {}).get(piexif.ExifIFD.UserComment, b"")
        try:
            exif_comment = piexif.helper.UserComment.load(exif_comment)
        except ValueError:
            exif_comment = exif_comment.decode("utf8", errors="ignore")

        items["exif comment"] = exif_comment
        geninfo = exif_comment

        for field in [
            "jfif",
            "jfif_version",
            "jfif_unit",
            "jfif_density",
            "dpi",
            "exif",
            "loop",
            "background",
            "timestamp",
            "duration",
        ]:
            items.pop(field, None)

    geninfo = items.get("parameters", geninfo)

    for key, text in items.items():
        print(key)
        print(text)

    print("geninfo", geninfo)
    print("a", a)
    print("c", c)
    print("rating", rating)
    print("character_res", character_res)
    print("general_res", general_res)

    # Reshape into lists of {tag, confidence}; a hard-coded 0.4 floor is applied on top of the thresholds above
    character_res = [{"tag": tag, "confidence": score}
                     for tag, score in character_res.items() if score > 0.4]

    general_res = [{"tag": tag, "confidence": score}
                   for tag, score in general_res.items() if score > 0.4]

    return {'a': a, 'c': c, 'rating': rating, 'character_res': character_res, 'general_res': general_res}


def label_img(
        image: PIL.Image.Image | str,
        model: str,
        l_score_general_threshold: float,
        l_score_character_threshold: float,
):
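    """Tag an image (PIL image or URL) with the chosen tagger model.

    model must be one of "SwinV2", "ConvNext", "ConvNextV2" or "ViT".
    """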
    if isinstance(image, str) and image.startswith("http"):
        image = dbimutils.read_img_from_url(image)

    # Ensure the requested model is loaded; the module-level cache persists across calls
    if loaded_models.get(model) is None:
        change_model(model)

    tag_names, rating_indexes, general_indexes, character_indexes = load_labels()

    func = functools.partial(
        predict,
        tag_names=tag_names,
        rating_indexes=rating_indexes,
        general_indexes=general_indexes,
        character_indexes=character_indexes,
    )

    return func(
        image=image, model_name=model,
        general_threshold=l_score_general_threshold,
        character_threshold=l_score_character_threshold,
    )


def write_image_tag(img_id: int, is_valid: bool, tags: List[model.ImageTag], callback_url: str):
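    """Build and return the ImageScanCallbackRequest payload for a scanned image.

    callback_url is accepted but not used here; dispatching the payload (e.g. an
    HTTP POST to callback_url) is assumed to happen elsewhere.
    """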
    return model.ImageScanCallbackRequest(img_id=img_id, is_valid=is_valid, tags=tags)


if __name__ == "__main__":
    score_general_threshold = 0.35
    score_character_threshold = 0.85

    ret = label_img(
        image='https://pub-9747017e9ec54620bfbe2385f14fe4d7.r2.dev/cnGirlYcy_v10_people_network_nannansleep/cnGirlYcy_v10_people_network_nannansleep_r_1679670778_0.png',
        model="SwinV2",
        l_score_general_threshold=score_general_threshold,
        l_score_character_threshold=score_character_threshold,
    )
    print(ret)