import json
import os
from difflib import SequenceMatcher
from typing import Any, Dict, Optional, Tuple

from fastapi import FastAPI, Request, Response
from huggingface_hub import (DatasetCard, HfApi, ModelCard, comment_discussion,
                             create_discussion, get_discussion_details,
                             get_repo_discussions, login)
from huggingface_hub.utils import EntryNotFoundError
from tabulate import tabulate

# Shared secret that incoming webhook requests must present in the
# ``X-Webhook-Secret`` header; ``None`` when WEBHOOK_SECRET is not set.
KEY = os.environ.get("WEBHOOK_SECRET")
# Hugging Face access token read from the environment; ``None`` when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Hub API client built with the token.
# NOTE(review): ``api`` is not referenced in the visible code - confirm it is
# still needed.
api = HfApi(token=HF_TOKEN)
# Runs at import time. Presumably this authenticates the module-level
# huggingface_hub helpers used below (create_discussion, comment_discussion,
# ...) - TODO confirm.
login(HF_TOKEN)

# Application object on which the route decorators in this file register
# their handlers.
app = FastAPI()


@app.get("/")
def read_root():
    """Serve a small static HTML landing page describing the bot."""
    landing_page = """
    <h2 style="text-align:center">Metadata Review Bot</h2>
    <p style="text-align:center">This is a demo app showing how to use webhooks to automate metadata review for models and datasets shared on the Hugging Face Hub.</p>
    """
    return Response(media_type="text/html", content=landing_page)


def similar(a, b):
    """Return the similarity ratio of two sequences.

    The value comes from ``difflib.SequenceMatcher.ratio`` and lies in
    ``[0.0, 1.0]``, where ``1.0`` means the sequences are identical.
    """
    matcher = SequenceMatcher(isjunk=None, a=a, b=b)
    ratio = matcher.ratio()
    return ratio


def create_metadata_key_dict(card_data, repo_type: str):
    """Pick the metadata fields of interest out of a repo card.

    Args:
        card_data: Mapping of card metadata; looked up with ``.get`` so
            absent fields come back as ``None``.
        repo_type: ``"model"`` or ``"dataset"``.

    Returns:
        A dict with the shared fields (``tags``, ``license``) followed by the
        type-specific fields, in that order. ``None`` for any other
        ``repo_type``.
    """
    common_fields = ("tags", "license")
    fields_by_type = {
        "model": ("library_name", "datasets", "metrics", "co2", "pipeline_tag"),
        "dataset": (
            "pretty_name",
            "size_categories",
            "task_categories",
            "task_ids",
            "source_datasets",
        ),
    }
    specific_fields = fields_by_type.get(repo_type)
    if specific_fields is None:
        return None
    selected = {}
    for field in common_fields + specific_fields:
        selected[field] = card_data.get(field)
    return selected


def create_metadata_breakdown_table(desired_metadata_dictionary):
    """Render the metadata fields as a GitHub-flavoured markdown table.

    Each row is ``(field name, value)``; any falsy value is shown as
    ``"Field Missing"``.
    """
    rows = [
        (field_name, field_value or "Field Missing")
        for field_name, field_value in desired_metadata_dictionary.items()
    ]
    column_titles = ("Metadata Field", "Provided Value")
    return tabulate(rows, headers=column_titles, tablefmt="github")


def calculate_grade(desired_metadata_dictionary):
    """Return the fraction of metadata fields that have a truthy value.

    Args:
        desired_metadata_dictionary: Mapping of field name to provided value.

    Returns:
        A float in ``[0.0, 1.0]`` rounded to two decimals. An empty mapping
        scores ``0.0`` rather than raising ``ZeroDivisionError``.
    """
    total_fields = len(desired_metadata_dictionary)
    if total_fields == 0:
        return 0.0
    filled_fields = sum(1 for value in desired_metadata_dictionary.values() if value)
    return round(filled_fields / total_fields, 2)


def create_markdown_report(
    desired_metadata_dictionary, repo_name, repo_type, score, update: bool = False
):
    """Build the markdown text of the metadata report card.

    Args:
        desired_metadata_dictionary: Mapping of field name to provided value,
            rendered as a table via ``create_metadata_breakdown_table``.
        repo_name: Repository name shown in the introduction.
        repo_type: Repository type string; used in the title and headings.
        score: Coverage grade as a fraction (``0.5`` means 50%).
        update: When true, the title is suffixed with ``(updated)``.

    Returns:
        The report as a single markdown string.
    """
    title_suffix = "(updated)" if update else ""
    # Scores at or below one half get the "try harder" message.
    if score <= 0.5:
        verdict = f"We're not angry we're just disappointed! {repo_type.title()} metadata is super important. Please try harder..."
    else:
        verdict = f"Not too shabby! Make sure you also fill in a {repo_type} card too!"
    table = create_metadata_breakdown_table(desired_metadata_dictionary)
    # ``score`` is a 0-1 fraction, so it is formatted as a percentage; printing
    # the raw value followed by "%" would show e.g. "0.57%" for 57% coverage.
    report = f"""# {repo_type.title()} metadata report card {title_suffix}
    \n
This is an automatically produced metadata quality report card for {repo_name}. This report is meant as a POC!
    \n 
## Breakdown of metadata fields for your {repo_type}
\n
{table}
\n
You scored a metadata coverage grade of: **{score:.0%}** \n {verdict}
    """
    return report


def parse_webhook_post(data: Dict[str, Any]) -> Optional[Tuple[str, str]]:
    """Extract ``(repo_type, repo_name)`` from a webhook payload.

    Args:
        data: Decoded JSON body of the webhook request.

    Returns:
        ``(repo_type, repo_name)`` for events whose scope is ``"repo"``.
        ``None`` when the event has a different scope or when the payload is
        missing the expected ``event``/``repo`` fields, so the caller can
        answer with a client error instead of crashing.

    Raises:
        ValueError: If the repo type is neither ``"model"`` nor ``"dataset"``.
    """
    # The payload is untrusted input: treat any missing or mis-shaped field as
    # "not parseable" rather than letting KeyError/TypeError escape.
    try:
        scope = data["event"]["scope"]
    except (KeyError, TypeError):
        return None
    if scope != "repo":
        return None
    try:
        repo = data["repo"]
        repo_name = repo["name"]
        repo_type = repo["type"]
    except (KeyError, TypeError):
        return None
    # Tuple membership so an unhashable value still reaches the ValueError.
    if repo_type not in ("model", "dataset"):
        raise ValueError("Unknown hub type")
    return repo_type, repo_name


def load_repo_card_metadata(repo_type, repo_name):
    """Load a repo card's metadata as a plain dict.

    Returns ``{}`` when loading raises ``EntryNotFoundError``, and ``None``
    when ``repo_type`` is neither ``"dataset"`` nor ``"model"``.
    """
    if repo_type != "dataset" and repo_type != "model":
        return None
    # Pick the card class once so a single try/except serves both repo types.
    card_class = DatasetCard if repo_type == "dataset" else ModelCard
    try:
        metadata = card_class.load(repo_name).data.to_dict()
    except EntryNotFoundError:
        metadata = {}
    return metadata


def create_or_update_report(data):
    """Create a metadata report discussion for a repo, or update the open one.

    Args:
        data: Decoded webhook payload, passed to ``parse_webhook_post``.

    Returns:
        A 400 ``Response`` when the payload cannot be parsed; otherwise
        ``True`` once the report has been posted or found to be unchanged.
    """
    parsed_post = parse_webhook_post(data)
    if not parsed_post:
        return Response("Unable to parse webhook data", status_code=400)
    repo_type, repo_name = parsed_post

    card_data = load_repo_card_metadata(repo_type, repo_name)
    desired_metadata_dictionary = create_metadata_key_dict(card_data, repo_type)
    score = calculate_grade(desired_metadata_dictionary)
    report = create_markdown_report(
        desired_metadata_dictionary, repo_name, repo_type, score, update=False
    )

    repo_discussions = get_repo_discussions(
        repo_name,
        repo_type=repo_type,
    )
    for discussion in repo_discussions:
        is_open_report_thread = (
            discussion.title == "Metadata Report Card" and discussion.status == "open"
        )
        if not is_open_report_thread:
            continue
        discussion_details = get_discussion_details(
            repo_name, discussion.num, repo_type=repo_type
        )
        last_comment = _latest_comment_text(discussion_details)
        updated_report = create_markdown_report(
            desired_metadata_dictionary,
            repo_name,
            repo_type,
            score,
            update=True,
        )
        # The last comment may be either the initial report or an earlier
        # "(updated)" one, so compare against both renderings. Comparing only
        # the non-updated text would never match a previous update and would
        # re-post an identical report on every webhook.
        unchanged = last_comment is not None and (
            similar(report, last_comment) > 0.999
            or similar(updated_report, last_comment) > 0.999
        )
        if not unchanged:
            comment_discussion(
                repo_name,
                discussion.num,
                comment=updated_report,
                repo_type=repo_type,
            )
        # Only the first open report thread is handled.
        return True

    # No open report thread exists yet: start one with the initial report.
    create_discussion(
        repo_name,
        "Metadata Report Card",
        description=report,
        repo_type=repo_type,
    )
    return True


def _latest_comment_text(discussion_details):
    """Return the text of the newest event that has a ``content`` attribute.

    Not every discussion event is a comment (status or title changes carry no
    ``content``), so the events are scanned from newest to oldest. Returns
    ``None`` when no event carries text.
    """
    for event in reversed(discussion_details.events):
        text = getattr(event, "content", None)
        if text is not None:
            return text
    return None


@app.post("/webhook")
async def webhook(request: Request):
    """Handle a webhook POST: authenticate it, then create or update the report.

    Returns:
        * 401 ``Response`` when the ``X-Webhook-Secret`` header is missing or
          wrong, or when no secret is configured on the server.
        * 400 ``Response`` when the body is not valid JSON, or when
          ``create_or_update_report`` reports an unparseable payload.
        * ``"Webhook received!"`` otherwise.
    """
    import hmac  # function-scope import keeps this block self-contained

    provided_secret = request.headers.get("X-Webhook-Secret")
    # Fail closed: with no configured secret, a request without the header
    # would otherwise compare None to None and be accepted.
    if not KEY or provided_secret is None:
        return Response("Invalid secret", status_code=401)
    # Constant-time comparison; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(provided_secret.encode("utf-8"), KEY.encode("utf-8")):
        return Response("Invalid secret", status_code=401)

    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        return Response("Invalid JSON payload", status_code=400)

    result = create_or_update_report(data)
    # An error Response object is truthy, so it must be passed through
    # explicitly instead of being replaced by the success message.
    if isinstance(result, Response):
        return result
    return "Webhook received!"