import collections
import json
import logging
import os
import threading

import datasets
import gradio as gr
from transformers.pipelines import TextClassificationPipeline

from io_utils import (
    read_column_mapping,
    save_job_to_pipe,
    write_column_mapping,
    write_log_to_user_file,
)
from text_classification import (
    check_model,
    get_example_prediction,
    get_labels_and_features_from_dataset,
)
from wordings import CONFIRM_MAPPING_DETAILS_FAIL_RAW

# Number of label / feature dropdown slots the UI works with; handlers always
# return exactly this many dropdown updates, hiding the unused ones.
MAX_LABELS: int = 20
MAX_FEATURES: int = 20

# Names of environment variables read via os.environ.get when building a job.
HF_REPO_ID: str = "HF_REPO_ID"
HF_SPACE_ID: str = "SPACE_ID"  # fallback when HF_REPO_ID is unset
HF_WRITE_TOKEN: str = "HF_WRITE_TOKEN"

# File the stored column mapping is read from.
CONFIG_PATH: str = "./config.yaml"


def check_dataset_and_get_config(dataset_id):
    """Reset the stored column mapping and list the configs of ``dataset_id``.

    Returns:
        A visible ``gr.Dropdown`` listing the dataset's config names with the
        first one pre-selected, or ``None`` when the configs cannot be fetched
        (e.g. the dataset does not exist or reports no configs).
    """
    try:
        write_column_mapping(None)
        configs = datasets.get_dataset_config_names(dataset_id)
        return gr.Dropdown(configs, value=configs[0], visible=True)
    except Exception as err:
        # Dataset may not exist; keep this best-effort (no UI error) but leave
        # a trace in the logs instead of silently swallowing the failure.
        logging.warning("Failed to get configs for dataset %s: %s", dataset_id, err)
        return None


def check_dataset_and_get_split(dataset_id, dataset_config):
    """List the splits of ``dataset_id`` / ``dataset_config`` as a dropdown.

    The split names are taken from the keys of ``datasets.load_dataset``'s
    result, so the dataset itself is loaded to answer this.

    Returns:
        A visible ``gr.Dropdown`` with the split names and the first split
        pre-selected, or ``None`` when the dataset/config cannot be loaded.
    """
    try:
        splits = list(datasets.load_dataset(dataset_id, dataset_config).keys())
        return gr.Dropdown(splits, value=splits[0], visible=True)
    except Exception as err:
        # Dataset may not exist; stay best-effort but log why it failed.
        logging.warning(
            "Failed to load dataset %s with config %s: %s",
            dataset_id,
            dataset_config,
            err,
        )
        return None


def write_column_mapping_to_config(dataset_id, dataset_config, dataset_split, *labels):
    """Merge the user's mapping-dropdown selections into the stored mapping.

    Args:
        dataset_id, dataset_config, dataset_split: Identify the dataset whose
            labels and features are being mapped.
        *labels: Values of all mapping dropdowns in UI order. The first
            ``MAX_LABELS`` entries are label dropdowns, the next
            ``MAX_FEATURES`` are feature dropdowns.

    Each non-empty label selection at position ``i`` is stored as
    ``mapping["labels"][selection] = ds_labels[i]``. Each non-empty feature
    selection is stored as ``mapping["features"][selection] = ds_features[i]``.
    The merged mapping is then persisted with ``write_column_mapping``.
    """
    # ``*labels`` is always a tuple (never None), so test for emptiness; do it
    # before loading the dataset so nothing is fetched needlessly.
    if not labels:
        return

    ds_labels, ds_features = get_labels_and_features_from_dataset(
        dataset_id, dataset_config, dataset_split
    )
    if not isinstance(ds_labels, list) or not isinstance(ds_features, list):
        # Dataset labels/features could not be determined: nothing to map.
        return

    all_mappings = read_column_mapping(CONFIG_PATH)
    if all_mappings is None:
        all_mappings = {}
    label_mapping = all_mappings.setdefault("labels", {})
    feature_mapping = all_mappings.setdefault("features", {})

    # zip() stops at the shorter sequence. Hidden dropdowns past the end of
    # the dataset's labels/features may still hold values, and they must not
    # index beyond ds_labels / ds_features.
    for selected, ds_label in zip(labels[:MAX_LABELS], ds_labels):
        if selected:
            label_mapping[selected] = ds_label

    feature_selections = labels[MAX_LABELS : MAX_LABELS + MAX_FEATURES]
    for selected, ds_feature in zip(feature_selections, ds_features):
        if selected:
            feature_mapping[selected] = ds_feature

    write_column_mapping(all_mappings)


def list_labels_and_features_from_dataset(ds_labels, ds_features, model_id2label):
    """Build the label and feature mapping dropdowns for the UI.

    Args:
        ds_labels: Label names of the dataset. One visible label dropdown is
            created per entry (at most ``MAX_LABELS``).
        ds_features: Dataset column names, offered as choices for the ``text``
            feature dropdown.
        model_id2label: The model's id -> label mapping. Its values are the
            choices of every label dropdown.

    Returns:
        A list of exactly ``MAX_LABELS + MAX_FEATURES`` ``gr.Dropdown``
        objects. The visible label dropdowns come first, padded with hidden
        ones, followed by the visible feature dropdown(s), padded with hidden
        ones.
    """
    model_labels = list(model_id2label.values())
    num_model_labels = len(model_labels)

    label_dropdowns = [
        gr.Dropdown(
            label=f"{label}",
            choices=model_labels,
            # Pre-select model labels round-robin by id (assumes ids are
            # 0..n-1 — TODO confirm). A label-less model gets no default
            # instead of raising ZeroDivisionError.
            value=(
                model_id2label[i % num_model_labels] if num_model_labels else None
            ),
            interactive=True,
            visible=True,
        )
        for i, label in enumerate(ds_labels[:MAX_LABELS])
    ]
    label_dropdowns += [
        gr.Dropdown(visible=False) for _ in range(MAX_LABELS - len(label_dropdowns))
    ]

    # TODO: Substitute 'text' with more features for zero-shot
    feature_dropdowns = [
        gr.Dropdown(
            label=f"{feature}",
            choices=ds_features,
            # An empty feature list would otherwise raise IndexError.
            value=ds_features[0] if ds_features else None,
            interactive=True,
            visible=True,
        )
        for feature in ["text"]
    ]
    feature_dropdowns += [
        gr.Dropdown(visible=False)
        for _ in range(MAX_FEATURES - len(feature_dropdowns))
    ]
    return label_dropdowns + feature_dropdowns


def check_model_and_show_prediction(
    model_id, dataset_id, dataset_config, dataset_split
):
    """Validate the model and dataset, then show an example prediction.

    Returns a tuple of ``3 + MAX_LABELS + MAX_FEATURES`` updates, in order:
    - the example input,
    - the example output,
    - the column-mapping container (it takes ``open=``, so presumably an
      accordion — TODO confirm),
    - one update per label/feature mapping dropdown.

    Every return path yields the same number of updates, as Gradio requires.
    """
    # Shared "hide everything" result for all failure paths.
    hidden_outputs = (
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(visible=False, open=False),
        *[gr.Dropdown(visible=False) for _ in range(MAX_LABELS + MAX_FEATURES)],
    )

    ppl = check_model(model_id)
    if ppl is None:  # pipeline not found
        gr.Warning("Model not found")
        return hidden_outputs
    if not isinstance(ppl, TextClassificationPipeline):
        gr.Warning("Please check your model.")
        return hidden_outputs

    model_id2label = ppl.model.config.id2label
    ds_labels, ds_features = get_labels_and_features_from_dataset(
        dataset_id, dataset_config, dataset_split
    )

    # When the dataset does not have (usable) labels or features.
    if (
        not isinstance(ds_labels, list)
        or not isinstance(ds_features, list)
        or not ds_features
    ):
        return hidden_outputs

    column_mappings = list_labels_and_features_from_dataset(
        ds_labels,
        ds_features,
        model_id2label,
    )

    # When model labels and dataset labels differ (compared as multisets), or
    # the first dataset column is not "text", show the manual column mapping.
    if (
        collections.Counter(model_id2label.values()) != collections.Counter(ds_labels)
        or ds_features[0] != "text"
    ):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(visible=True, open=True),
            *column_mappings,
        )

    prediction_input, prediction_output = get_example_prediction(
        ppl, dataset_id, dataset_config, dataset_split
    )
    return (
        gr.update(value=prediction_input, visible=True),
        gr.update(value=prediction_output, visible=True),
        gr.update(visible=True, open=False),
        *column_mappings,
    )


# One lock shared by every submission. A lock created inside each call is
# private to that call and could never serialize concurrent submissions.
_JOB_PIPE_LOCK = threading.Lock()


def try_submit(m_id, d_id, config, split, local, uid):
    """Queue a scan job for a model/dataset pair using the stored column mapping.

    Args:
        m_id: Model id passed to ``cli.py --model``.
        d_id, config, split: Dataset id, config and split passed to ``cli.py``.
        local: When truthy, the job command is queued via ``save_job_to_pipe``.
            Otherwise only an informational TODO message is shown.
        uid: User id used for the job queue and the user's log file.

    Returns:
        A 2-tuple of updates. The first is for the submit button, which is
        disabled once a local job is queued. The second is for the log output
        (presumably a textbox — TODO confirm), which is shown only when a
        local job is queued.
    """
    all_mappings = read_column_mapping(CONFIG_PATH)
    if (
        all_mappings is None
        or "labels" not in all_mappings
        or "features" not in all_mappings
    ):
        gr.Warning(CONFIRM_MAPPING_DETAILS_FAIL_RAW)
        return (gr.update(interactive=True), gr.update(visible=False))
    label_mapping = all_mappings["labels"]
    feature_mapping = all_mappings["features"]

    # TODO: Set column mapping for some dataset such as `amazon_polarity`
    if not local:
        gr.Info("TODO: Submit task to an endpoint")
        return (gr.update(interactive=True), gr.update(visible=False))  # Submit button

    # NOTE(review): os.environ.get returns None when a variable is unset, which
    # would put None into this argv list — verify how save_job_to_pipe/cli.py
    # cope with that.
    command = [
        "python",
        "cli.py",
        "--loader",
        "huggingface",
        "--model",
        m_id,
        "--dataset",
        d_id,
        "--dataset_config",
        config,
        "--dataset_split",
        split,
        "--hf_token",
        os.environ.get(HF_WRITE_TOKEN),
        "--discussion_repo",
        os.environ.get(HF_REPO_ID) or os.environ.get(HF_SPACE_ID),
        "--output_format",
        "markdown",
        "--output_portal",
        "huggingface",
        "--feature_mapping",
        json.dumps(feature_mapping),
        "--label_mapping",
        json.dumps(label_mapping),
        "--scan_config",
        "../config.yaml",
    ]

    eval_str = f"[{m_id}]<{d_id}({config}, {split} set)>"
    logging.info(f"Start local evaluation on {eval_str}")
    save_job_to_pipe(uid, command, _JOB_PIPE_LOCK)
    write_log_to_user_file(
        uid,
        f"Start local evaluation on {eval_str}. Please wait for your job to start...\n",
    )
    gr.Info(f"Start local evaluation on {eval_str}")

    return (
        gr.update(interactive=False),
        gr.update(lines=5, visible=True, interactive=False),
    )