import os
import sys
import time
from functools import wraps
from typing import Any, Literal

from gradio import ChatMessage
from gradio.components.chatbot import Message

# URL path suffix for a discussions page (presumably appended to a model's
# MODEL_HF_URL -- confirm at the call sites).
COMMUNITY_POSTFIX_URL = "/discussions"
# Debug switches: each is on only when its environment variable is set to the
# exact string "True"; any other value (or an unset variable) leaves it off.
DEBUG_MODE = os.getenv("DEBUG_MODE") == "True"
DEBUG_MODEL = os.getenv("DEBUG_MODEL") == "True"

# Registry of the selectable models, keyed by display name.
# MODEL_NAME / VLLM_API_URL / AUTH_TOKEN are read from the environment once, at
# import time, and are None when the corresponding variable is unset.
models_config = {
    "Apriel-Nemotron-15b-Thinker": {
        "MODEL_DISPLAY_NAME": "Apriel-Nemotron-15b-Thinker",
        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker",
        "MODEL_NAME": os.getenv("MODEL_NAME_NEMO_15B"),
        "VLLM_API_URL": os.getenv("VLLM_API_URL_NEMO_15B"),
        # Both entries read the same AUTH_TOKEN variable.
        "AUTH_TOKEN": os.getenv("AUTH_TOKEN"),
        "REASONING": True,
    },
    "Apriel-5b": {
        "MODEL_DISPLAY_NAME": "Apriel-5b",
        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct",
        "MODEL_NAME": os.getenv("MODEL_NAME_5B"),
        "VLLM_API_URL": os.getenv("VLLM_API_URL_5B"),
        "AUTH_TOKEN": os.getenv("AUTH_TOKEN"),
        "REASONING": False,
    },
}


def get_model_config(model_name: str) -> dict:
    """Return the ``models_config`` entry for ``model_name`` after validating it.

    Args:
        model_name: Key to look up in ``models_config``.

    Returns:
        The configuration dict stored under ``model_name`` (the same object,
        not a copy).

    Raises:
        ValueError: If ``model_name`` has no (non-empty) entry, or if the
            entry's ``MODEL_NAME`` or ``VLLM_API_URL`` is missing or empty.
    """
    entry = models_config.get(model_name)
    if not entry:
        raise ValueError(f"Model {model_name} not found in models_config")

    # Required settings, checked in this order; the label is the lead-in of
    # the error message raised when the setting is absent.
    required_settings = (
        ("MODEL_NAME", "Model name"),
        ("VLLM_API_URL", "VLLM API URL"),
    )
    for setting, label in required_settings:
        if not entry.get(setting):
            raise ValueError(f"{label} not found in config for {model_name}")

    return entry


def _log_message(prefix, message, icon=""):
    """Print one log line: ``<local timestamp>: <prefix> [<icon> ]<message>``.

    Args:
        prefix: Level tag placed right after the timestamp.
        message: Text of the log entry.
        icon: Optional marker; when non-empty it is printed, followed by a
            single space, in front of the message.
    """
    # strftime formats the current local time when no time tuple is passed.
    stamp = time.strftime("%Y-%m-%d %H:%M:%S")
    body = f"{icon} {message}" if icon else f"{message}"
    print(f"{stamp}: {prefix} {body}")


def log_debug(message):
    """Log ``message`` with the DEBUG prefix, only while ``DEBUG_MODE`` is ``True``."""
    if DEBUG_MODE is not True:
        return
    _log_message(prefix="DEBUG", message=message)


def log_info(message):
    """Log ``message`` with the INFO prefix (always emitted)."""
    _log_message(prefix="INFO ", message=message)


def log_warning(message):
    """Log ``message`` with the WARN prefix and a warning icon (always emitted)."""
    _log_message(prefix="WARN ", message=message, icon="⚠️")


def log_error(message):
    """Log ``message`` with the ERROR prefix and an alert icon (always emitted)."""
    _log_message(prefix="ERROR", message=message, icon="‼️")


# Gradio 5.0.1 had issues with checking the message formats.  5.29.0 does not!
def check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    """Validate chat history against the given format; a no-op unless DEBUG_MODE is on.

    Args:
        messages: Iterable of chat entries to validate.
        type: ``"messages"`` expects each entry to be a dict containing the
            ``"role"`` and ``"content"`` keys, or a ``ChatMessage``/``Message``
            instance. Any other value selects the tuples check, where each
            entry must be a tuple or list of length 2.

    Raises:
        Exception: If any entry does not match the selected format. For the
            messages format, the first offending entry is printed to stderr
            before raising.
    """
    if not DEBUG_MODE:
        return

    def _is_valid_message(candidate) -> bool:
        # The dict shape is tested first so the class check only runs for
        # entries that are not well-formed dicts.
        if isinstance(candidate, dict) and "role" in candidate and "content" in candidate:
            return True
        return isinstance(candidate, ChatMessage | Message)

    if type != "messages":
        pairs_ok = all(
            isinstance(pair, (tuple, list)) and len(pair) == 2
            for pair in messages
        )
        if not pairs_ok:
            raise Exception(
                "Data incompatible with tuples format. Each message should be a list of length 2."
            )
        return

    if all(_is_valid_message(entry) for entry in messages):
        return

    # Second pass: report the first entry that fails validation.
    first_invalid = next(
        ((idx, entry) for idx, entry in enumerate(messages) if not _is_valid_message(entry)),
        None,
    )
    if first_invalid is not None:
        bad_index, bad_entry = first_invalid
        print(f"_check_format() --> Invalid message at index {bad_index}: {bad_entry}\n", file=sys.stderr)

    raise Exception(
        "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
    )


# Adds timing info for a gradio event handler (non-generator functions)
def logged_event_handler(log_msg='', event_handler=None, log_timer=None, clear_timer=False):
    """Wrap ``event_handler`` so every call is bracketed by debug logs and timer steps.

    Args:
        log_msg: Label used in the debug log lines and in the timer step names.
        event_handler: The callable to wrap. It defaults to ``None`` but must
            be supplied: the returned wrapper calls it unconditionally.
        log_timer: Optional timer object providing ``add_step(str)`` and
            ``clear()``. Timer bookkeeping is skipped when this is falsy.
        clear_timer: When true and ``log_timer`` is set, ``log_timer.clear()``
            is called before the "Start" step is recorded.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``event_handler`` and returns its result unchanged. If the handler
        raises, the exception propagates and the "Completed"/"After" entries
        are not recorded.
    """
    @wraps(event_handler)
    def wrapped_event_handler(*args, **kwargs):
        # Log before
        if log_timer:
            if clear_timer:
                log_timer.clear()
            # Use the event label here (previously the log_debug function
            # object was interpolated by mistake), so the "Start" step pairs
            # up with the "Completed" step below.
            log_timer.add_step(f"Start: {log_msg}")
        log_debug(f"::: Before event: {log_msg}")

        # Call the original event handler
        result = event_handler(*args, **kwargs)

        # Log after
        if log_timer:
            log_timer.add_step(f"Completed: {log_msg}")
        log_debug(f"::: After event: {log_msg}")

        return result

    return wrapped_event_handler