import collections
import logging
import threading
import uuid

import datasets
import gradio as gr
import pandas as pd

import leaderboard
from io_utils import (
    read_column_mapping,
    write_column_mapping,
    read_scanners,
    write_scanners,
)
from run_jobs import save_job_to_pipe
from text_classification import (
    check_model_task,
    preload_hf_inference_api,
    get_example_prediction,
    get_labels_and_features_from_dataset,
    check_hf_token_validity,
    HuggingFaceInferenceAPIResponse,
)
from wordings import (
    EXAMPLE_MODEL_ID,
    CHECK_CONFIG_OR_SPLIT_RAW,
    CONFIRM_MAPPING_DETAILS_FAIL_RAW,
    MAPPING_STYLED_ERROR_WARNING,
    NOT_FOUND_DATASET_RAW,
    NOT_FOUND_MODEL_RAW,
    NOT_TEXT_CLASSIFICATION_MODEL_RAW,
    UNMATCHED_MODEL_DATASET_STYLED_ERROR,
    CHECK_LOG_SECTION_RAW,
    VALIDATED_MODEL_DATASET_STYLED,
    get_dataset_fetch_error_raw,
)
import os
from app_env import HF_WRITE_TOKEN

# Maximum number of label-mapping dropdowns rendered in the UI; datasets with
# more labels are truncated in list_labels_and_features_from_dataset.
MAX_LABELS = 40
# Maximum number of feature-mapping dropdowns rendered in the UI.
MAX_FEATURES = 20

# NOTE(review): module-level placeholders not referenced in this part of the
# file — possibly used elsewhere; confirm before removing.
ds_dict = None
ds_config = None


def get_related_datasets_from_leaderboard(model_id, dataset_id_input):
    """Build a dropdown update listing datasets already evaluated with ``model_id``.

    The unique ``dataset_id`` values of the leaderboard rows whose
    ``model_id`` matches become the dropdown choices. The current input is
    kept when it is one of them; otherwise the selected value is cleared.
    With no matching rows, the choices are emptied.
    """
    board = leaderboard.records
    matching_rows = board[board["model_id"] == model_id]
    candidates = list(matching_rows["dataset_id"].unique())

    if not candidates:
        return gr.update(choices=[])

    if dataset_id_input not in candidates:
        return gr.update(choices=candidates, value="")

    return gr.update(choices=candidates)


# Module logger. Note: it is named after this file's path (__file__) rather
# than the conventional module name (__name__).
logger = logging.getLogger(__file__)


def get_dataset_splits(dataset_id, dataset_config):
    """Return a dropdown update listing the splits of a dataset config.

    The first split is preselected. When the split names cannot be fetched
    (or the list is empty), a warning is logged and the dropdown is hidden.
    """
    try:
        split_names = datasets.get_dataset_split_names(
            dataset_id, dataset_config, trust_remote_code=True
        )
        first_split = split_names[0]
    except Exception as err:
        logger.warning(
            f"Check your dataset {dataset_id} and config {dataset_config}: {err}"
        )
        return gr.update(visible=False)
    return gr.update(choices=split_names, value=first_split, visible=True)


def check_dataset(dataset_id):
    """Validate ``dataset_id`` and build updates for the config/split dropdowns.

    Returns a 3-tuple ``(config_update, split_update, "")``. On success the
    config dropdown lists every config with the first one selected, and the
    split dropdown lists the splits of that first config. For an empty id, a
    dataset without configs, or any fetch error, both dropdowns are hidden;
    fetch errors are logged and reported to the user via ``gr.Warning``.
    """

    def _hidden():
        return (gr.update(visible=False), gr.update(visible=False), "")

    logger.info(f"Loading {dataset_id}")
    if not dataset_id:
        return _hidden()

    try:
        config_names = datasets.get_dataset_config_names(
            dataset_id, trust_remote_code=True
        )
        if not config_names:
            return _hidden()
        default_config = config_names[0]
        split_names = datasets.get_dataset_split_names(
            dataset_id, default_config, trust_remote_code=True
        )
        return (
            gr.update(choices=config_names, value=default_config, visible=True),
            gr.update(choices=split_names, value=split_names[0], visible=True),
            "",
        )
    except Exception as err:
        logger.warning(f"Check your dataset {dataset_id}: {err}")
        if "doesn't exist on the Hub or cannot be accessed" in str(err):
            gr.Warning(NOT_FOUND_DATASET_RAW)
        else:
            # "Forbidden" responses (GSK-2770: illegal dataset name) and
            # unknown errors are reported with the same generic message.
            gr.Warning(get_dataset_fetch_error_raw(err))

    return _hidden()


def empty_column_mapping(uid):
    """Clear the stored column mapping of session ``uid`` by writing ``None``."""
    cleared_mapping = None
    write_column_mapping(cleared_mapping, uid)


def write_column_mapping_to_config(uid, *labels):
    """Persist the current values of the mapping dropdowns for session ``uid``.

    ``labels`` holds every mapping dropdown value in UI order: the first
    ``MAX_LABELS`` entries come from the label dropdowns, the next
    ``MAX_FEATURES`` entries from the feature dropdowns. Label values are
    written onto the label keys already stored for ``uid``; feature values
    are written under the single supported feature, ``"text"``.

    Does nothing when no dropdown values are given or when no mapping is
    stored for ``uid`` yet (presumably the state left by
    ``empty_column_mapping``, which writes ``None``).
    """
    # TODO: Substitute 'text' with more features for zero-shot
    # we are not using ds features because we only support "text" for now

    # ``*labels`` is always a tuple (never None), so "nothing to write" means
    # an empty tuple; cycling through it would divide by zero.
    if not labels:
        return

    all_mappings = read_column_mapping(uid)
    if all_mappings is None:
        # There are no stored label keys to align the dropdown values with.
        logger.warning(f"No column mapping stored for {uid}, skipping update")
        return

    all_mappings = export_mappings(all_mappings, "labels", None, labels[:MAX_LABELS])

    feature_values = labels[MAX_LABELS : (MAX_LABELS + MAX_FEATURES)]
    # Callers may pass only label values; an empty slice must not be cycled.
    if feature_values:
        all_mappings = export_mappings(
            all_mappings,
            "features",
            ["text"],
            feature_values,
        )

    write_column_mapping(all_mappings, uid)


def export_mappings(all_mappings, key, subkeys, values):
    """Fill ``all_mappings[key]`` by assigning ``values`` to ``subkeys``.

    Args:
        all_mappings: Mapping dict; updated in place and also returned.
        key: Top-level section to fill (e.g. ``"labels"`` or
            ``"features"``); created as an empty dict when missing.
        subkeys: Keys to assign inside ``all_mappings[key]``. ``None`` means
            "reuse the keys already present in that section". Falsy entries
            are skipped.
        values: Sequence of values; the i-th subkey receives
            ``values[i % len(values)]``, so a shorter sequence is cycled.

    Returns:
        ``all_mappings``. Apart from creating the empty section, it is left
        unchanged when there are no subkeys or no values.
    """
    section = all_mappings.setdefault(key, {})
    if subkeys is None:
        subkeys = list(section.keys())

    if not subkeys:
        logging.debug(f"subkeys is empty for {key}")
        return all_mappings

    if len(values) == 0:
        # Cycling through an empty sequence would raise ZeroDivisionError.
        logging.debug(f"values is empty for {key}")
        return all_mappings

    for i, subkey in enumerate(subkeys):
        if subkey:
            section[subkey] = values[i % len(values)]
    return all_mappings


def list_labels_and_features_from_dataset(ds_labels, ds_features, model_labels, uid):
    """Build the label/feature mapping dropdowns and store a default mapping.

    Each dataset label gets a dropdown whose choices are the model labels.
    When the two share labels, only the shared ones are used, and the list
    is truncated to ``MAX_LABELS``. Both label lists are sorted and paired
    by position (cycling through the model labels if there are fewer) to
    form the default label mapping. One dropdown maps the ``"text"`` feature
    to a dataset column, preselecting the first entry of ``ds_features``.
    The default mapping is written for ``uid``.

    The input lists are not modified.

    Returns:
        A list of ``MAX_LABELS`` label dropdowns followed by ``MAX_FEATURES``
        feature dropdowns; unused slots are hidden.
    """
    all_mappings = read_column_mapping(uid)
    if all_mappings is None:
        # No mapping stored for this session yet: start from an empty one.
        all_mappings = {}

    # For flattened raw datasets with no labels
    # check if there are shared labels between model and dataset
    shared_labels = set(model_labels).intersection(ds_labels)
    if shared_labels:
        ds_labels = list(shared_labels)
    if len(ds_labels) > MAX_LABELS:
        ds_labels = ds_labels[:MAX_LABELS]
        gr.Warning(
            f"Too many labels to display for this space. We do not support more than {MAX_LABELS} in this space. You can use cli tool at https://github.com/Giskard-AI/cicd."
        )

    # sort labels to make sure the order is consistent
    # prediction gives the order based on probability
    # sorted() builds new lists, so the caller's lists are not reordered.
    ds_labels = sorted(ds_labels)
    model_labels = sorted(model_labels)

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            value=model_labels[i % len(model_labels)],
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels)
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "labels", ds_labels, model_labels)

    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            value=ds_features[0],
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False)
        for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    all_mappings = export_mappings(all_mappings, "features", ["text"], ds_features)
    write_column_mapping(all_mappings, uid)

    return label_dropdowns + feature_dropdowns


def precheck_model_ds_enable_example_btn(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Check the model/dataset selection and prepare the dataset preview.

    Returns a 7-tuple of Gradio updates, in order:

    1. The example button. It is ``interactive`` only when the model is a
       text-classification model and the split exposes labels and features
       as lists.
    2. A preview of the first 5 rows of the split, shown when it loads.
    3-6. Four further updates that are always hidden here.
    7. An HTML status message, shown only when the model is missing or is
       not a text-classification model.
    """

    def _outputs(interactive=False, preview=None, status_html=None):
        # Build the 7 output updates shared by every exit path.
        return (
            gr.update(interactive=interactive),
            (
                gr.update(visible=False)
                if preview is None
                else gr.update(value=preview, visible=True)
            ),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            (
                gr.update(visible=False)
                if status_html is None
                else gr.update(value=status_html, visible=True)
            ),
        )

    model_task = check_model_task(model_id)
    if not model_task:
        # Model might be not found
        error_msg_html = f"<p style='color: red;'>{NOT_FOUND_MODEL_RAW}</p>"
        if model_id.startswith(("http://", "https://")):
            error_msg = f"Please input your model id, such as {EXAMPLE_MODEL_ID}, instead of URL"
            error_msg_html = f"<p style='color: red;'>{error_msg}</p>"
        return _outputs(status_html=error_msg_html)

    if model_task != "text-classification":
        gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
        # The dataset has not been loaded at this point, so there is no
        # preview to show (referencing it here used to raise
        # UnboundLocalError).
        return _outputs(
            status_html=f"<p style='color: red;'>{NOT_TEXT_CLASSIFICATION_MODEL_RAW}</p>"
        )

    preload_hf_inference_api(model_id)

    if dataset_config is None or dataset_split is None or len(dataset_config) == 0:
        return _outputs()

    try:
        ds = datasets.load_dataset(dataset_id, dataset_config, trust_remote_code=True)
        split_ds = ds[dataset_split]
        df: pd.DataFrame = split_ds.to_pandas().head(5)
        ds_labels, ds_features, _ = get_labels_and_features_from_dataset(split_ds)
    except Exception as e:
        # Config or split wrong
        logger.warning(
            f"Check your dataset {dataset_id} and config {dataset_config} on split {dataset_split}: {e}"
        )
        return _outputs()

    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return _outputs(preview=df)

    return _outputs(interactive=True, preview=df)


def align_columns_and_show_prediction(
    model_id,
    dataset_id,
    dataset_config,
    dataset_split,
    uid,
    inference_token,
):
    """Run an example prediction and prepare the label/feature mapping UI.

    Every path returns ``6 + MAX_LABELS + MAX_FEATURES`` values, in order:

    1. A status-message update.
    2. An update for the example input, sized via ``lines`` from its length.
    3. An update for the example prediction.
    4. An update that takes ``open``; it is expanded only when the mapping
       needs manual review.
    5. The run-button ``interactive`` update; it is enabled only when
       ``inference_token`` is not ``""``.
    6. A message string, non-empty only while the HF Inference API is
       loading the model.
    7+. The label and feature mapping dropdowns.

    Early exits hide everything and disable the run button.
    """
    model_task = check_model_task(model_id)
    if model_task != "text-classification":
        gr.Warning(NOT_TEXT_CLASSIFICATION_MODEL_RAW)
        # Must return as many values as every other path; the leading
        # status/input/prediction updates are all hidden.
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False, open=False),
            gr.update(interactive=False),
            "",
            *[gr.update(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
        )

    dropdown_placement = [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)
    ]

    hf_token = os.environ.get(HF_WRITE_TOKEN, default="")

    prediction_input, prediction_response = get_example_prediction(
        model_id, dataset_id, dataset_config, dataset_split, hf_token
    )

    if prediction_input is None or prediction_response is None:
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False, open=False),
            gr.update(interactive=False),
            "",
            *dropdown_placement,
        )

    if isinstance(prediction_response, HuggingFaceInferenceAPIResponse):
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False, open=False),
            gr.update(interactive=False),
            f"Hugging Face Inference API is loading your model. {prediction_response.message}",
            *dropdown_placement,
        )

    model_labels = list(prediction_response.keys())

    ds = datasets.load_dataset(
        dataset_id, dataset_config, split=dataset_split, trust_remote_code=True
    )
    ds_labels, ds_features, _ = get_labels_and_features_from_dataset(ds)

    # when dataset does not have labels or features
    # (an empty feature list would fail on ds_features[0] below)
    if (
        not isinstance(ds_labels, list)
        or not isinstance(ds_features, list)
        or not ds_features
    ):
        gr.Warning(CHECK_CONFIG_OR_SPLIT_RAW)
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False, open=False),
            gr.update(interactive=False),
            "",
            *dropdown_placement,
        )

    if len(ds_labels) != len(model_labels):
        return (
            gr.update(value=UNMATCHED_MODEL_DATASET_STYLED_ERROR, visible=True),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=False, open=False),
            gr.update(interactive=False),
            "",
            *dropdown_placement,
        )

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_labels,
        uid,
    )

    # Show the mapping section expanded for manual review when the label
    # sets differ or the first dataset feature is not "text".
    needs_manual_mapping = (
        collections.Counter(model_labels) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    )
    status_message = (
        MAPPING_STYLED_ERROR_WARNING
        if needs_manual_mapping
        else VALIDATED_MODEL_DATASET_STYLED
    )

    return (
        gr.update(value=status_message, visible=True),
        gr.update(
            value=prediction_input,
            lines=min(len(prediction_input) // 225 + 1, 5),
            visible=True,
        ),
        gr.update(value=prediction_response, visible=True),
        gr.update(visible=True, open=needs_manual_mapping),
        gr.update(interactive=(inference_token != "")),
        "",
        *column_mappings,
    )


def check_column_mapping_keys_validity(all_mappings):
    """Return whether ``all_mappings`` is a usable column mapping.

    A mapping is usable when it exists and has a ``"labels"`` section.
    A missing mapping is logged and reported to the user via ``gr.Warning``;
    a mapping without labels is only logged.
    """
    if all_mappings is None:
        logger.warning("all_mapping is None")
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return False

    has_labels = "labels" in all_mappings
    if not has_labels:
        logger.warning(f"Label mapping is not valid, all_mappings: {all_mappings}")
    return has_labels


def enable_run_btn(
    uid, inference_token, model_id, dataset_id, dataset_config, dataset_split
):
    """Return an update enabling the run button only when a job can be submitted.

    The checks, in order, are:

    - an inference token is given;
    - model id, dataset id, config and split are all non-empty;
    - the column mapping stored for ``uid`` has a labels section;
    - the token passes ``check_hf_token_validity``.

    The first failing check is logged and disables the button.
    """
    if inference_token == "":
        logger.warning("Inference API is not enabled")
        return gr.update(interactive=False)

    if "" in (model_id, dataset_id, dataset_config, dataset_split):
        logger.warning("Model id or dataset id is not selected")
        return gr.update(interactive=False)

    stored_mapping = read_column_mapping(uid)
    if not check_column_mapping_keys_validity(stored_mapping):
        logger.warning("Column mapping is not valid")
        return gr.update(interactive=False)

    token_ok = check_hf_token_validity(inference_token)
    if not token_ok:
        logger.warning("HF token is not valid")
    return gr.update(interactive=bool(token_ok))


def construct_label_and_feature_mapping(
    all_mappings, ds_labels, ds_features, label_keys=None
):
    """Build the label and feature mappings sent with an evaluation job.

    Args:
        all_mappings: Stored column mapping. Its ``"labels"`` section maps
            each dataset label to a model label. A ``"features"`` section is
            expected; when it is missing, a warning is shown and an empty
            feature mapping is used.
        ds_labels: Dataset labels, in dataset order.
        ds_features: Dataset features; only used to sanity-check the size of
            the stored feature mapping.
        label_keys: Optional sequence; its first element, if any, is added to
            the feature mapping under ``"label"`` (presumably the dataset's
            label column name — confirm against
            ``get_labels_and_features_from_dataset``).

    Returns:
        ``(label_mapping, feature_mapping)``. ``label_mapping`` maps the
        stringified dataset-label index to the mapped model label.
        ``feature_mapping`` is a copy, so ``all_mappings`` is not modified.

    Raises:
        KeyError: if a dataset label is missing from ``all_mappings["labels"]``.
    """
    if "features" not in all_mappings:
        logger.warning("features not in all_mappings")
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
    # Copy so that adding the "label" entry below does not mutate the
    # caller's stored mapping.
    feature_mapping = dict(all_mappings.get("features", {}))

    if len(all_mappings["labels"].keys()) != len(ds_labels):
        logger.warning(
            f"""Label mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}. 
                    \nall_mappings: {all_mappings}\nds_labels: {ds_labels}"""
        )

    if len(feature_mapping) != len(ds_features):
        logger.warning(
            f"""Feature mapping corrupted: {CONFIRM_MAPPING_DETAILS_FAIL_RAW}. 
                    \nall_mappings: {all_mappings}\nds_features: {ds_features}"""
        )

    # align the saved labels with dataset labels order
    label_mapping = {
        str(i): all_mappings["labels"][label] for i, label in enumerate(ds_labels)
    }

    # label_keys defaults to None, so test truthiness rather than len().
    if label_keys:
        feature_mapping["label"] = label_keys[0]
    return label_mapping, feature_mapping


def show_hf_token_info(token):
    """Return an update that shows the token-info component only for an invalid token."""
    token_is_valid = check_hf_token_validity(token)
    return gr.update(visible=not token_is_valid)


# One lock shared by every submission so that save_job_to_pipe can actually
# serialize concurrent writers; a new Lock per call is never contended.
_JOB_PIPE_LOCK = threading.Lock()


def try_submit(m_id, d_id, config, split, inference_token, uid, verbose):
    """Queue an evaluation job for session ``uid`` and reset the form.

    The job tuple passed to ``save_job_to_pipe`` contains:

    - the model id;
    - the dataset id, config and split;
    - the inference token and ``uid``;
    - the label and feature mappings rebuilt against the dataset;
    - ``verbose``.

    Always returns a 7-tuple: the submit-button update, the log-area update,
    the session id to use from now on, and four further component updates.

    On success, a fresh ``uuid4`` is allocated, the scanners of ``uid`` are
    copied to it, and the four trailing components are hidden.

    When the stored column mapping is invalid, nothing is queued:

    - the submit button stays enabled;
    - the log area is hidden;
    - ``uid`` is kept;
    - the trailing components are left unchanged.
    """
    all_mappings = read_column_mapping(uid)
    if not check_column_mapping_keys_validity(all_mappings):
        # Return as many values as the success path; Gradio rejects handlers
        # that return fewer values than their declared outputs.
        return (
            gr.update(interactive=True),  # Submit button
            gr.update(visible=False),
            uid,  # Keep the current session id
            gr.update(),
            gr.update(),
            gr.update(),
            gr.update(),
        )

    # get ds labels and features again for alignment
    ds = datasets.load_dataset(d_id, config, split=split, trust_remote_code=True)
    ds_labels, ds_features, label_keys = get_labels_and_features_from_dataset(ds)
    label_mapping, feature_mapping = construct_label_and_feature_mapping(
        all_mappings, ds_labels, ds_features, label_keys
    )

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    save_job_to_pipe(
        uid,
        (
            m_id,
            d_id,
            config,
            split,
            inference_token,
            uid,
            label_mapping,
            feature_mapping,
            verbose,
        ),
        eval_str,
        _JOB_PIPE_LOCK,
    )
    gr.Info("Your evaluation has been submitted")

    new_uid = uuid.uuid4()
    scanners = read_scanners(uid)
    write_scanners(scanners, new_uid)

    return (
        gr.update(interactive=False),  # Submit button
        gr.update(
            value=f"{CHECK_LOG_SECTION_RAW}Your job id is: {uid}. ",
            lines=5,
            visible=True,
            interactive=False,
        ),
        new_uid,  # Allocate a new uuid
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False),
    )