import os
import logging
import gradio as gr
from typing import Iterator
from gateway import request_generation
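
# `request_generation` (imported above) is expected to yield incremental text
# chunks from the gateway's streaming endpoint; `generate` below accumulates
# those chunks into the full message that the UI re-renders on each step.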

# Set up logging
logging.basicConfig(level=logging.INFO)

# Validate environment variables
CLOUD_GATEWAY_API = os.getenv("API_ENDPOINT")
if not CLOUD_GATEWAY_API:
    raise EnvironmentError("API_ENDPOINT is not set.")

MODEL_NAME = os.getenv("MODEL_NAME")
if not MODEL_NAME:
    raise EnvironmentError("MODEL_NAME is not set.")

# Get the API key
API_KEY = os.getenv("API_KEY")
if not API_KEY:  # presence check only; the gateway validates the key itself
    raise EnvironmentError("API_KEY is not set.")

# Build the auth header once so it is not recreated on every request
HEADER = {"x-api-key": API_KEY}

def generate(
    message: str,
    chat_history: list,
    system_prompt: str,
    temperature: float = 0.6,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
) -> Iterator[str]:
    """Send a request to backend, fetch the streaming responses and emit to the UI.

    Args:
        message (str): input message from the user
        chat_history (list[tuple[str, str]]): entire chat history of the session
        system_prompt (str): system prompt
        temperature (float, optional): the value used to module the next token probabilities. Defaults to 0.6.
        top_p (float, optional): if set to float<1, only the smallest set of most probable tokens with probabilities
                                    that add up to top_p or higher are kept for generation. Defaults to 0.9.
        top_k (int, optional): the number of highest probability vocabulary tokens to keep for top-k-filtering.
                                Defaults to 50.
        repetition_penalty (float, optional): the parameter for repetition penalty. 1.0 means no penalty.
                                Defaults to 1.2.

    Yields:
        Iterator[str]: Streaming responses to the UI
    """
    # Accumulate streamed chunks and re-yield the growing text so the chat UI
    # re-renders the complete message as it streams in.
    outputs = []
    for text in request_generation(
        header=HEADER,
        message=message,
        system_prompt=system_prompt,
        temperature=temperature,
        presence_penalty=presence_penalty,
        frequency_penalty=frequency_penalty,
        cloud_gateway_api=CLOUD_GATEWAY_API,
        model_name=MODEL_NAME,
    ):
        outputs.append(text)
        yield "".join(outputs)


description = """
This Space is an alpha release demonstrating the [Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B) model running on AMD MI300 infrastructure. Use of the model is governed by the Qwen 3 [License](https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/LICENSE). Feel free to play with it!
"""

demo = gr.ChatInterface(
    fn=generate,
    type="messages",
    chatbot=gr.Chatbot(
        type="messages",
        scale=2,
        allow_tags=True,
    ),
    stop_btn=None,
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            value="You are a highly capable AI assistant. Provide accurate, concise, and fact-based responses that are directly relevant to the user's query. Avoid speculation, ensure logical consistency, and maintain clarity in longer outputs. Keep answers well-structured and under 1200 tokens unless explicitly requested otherwise.",
            lines=3,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        ),
        gr.Slider(
            label="Frequency penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
        gr.Slider(
            label="Presence penalty",
            minimum=-2.0,
            maximum=2.0,
            step=0.1,
            value=0.0,
        ),
    ],
    examples=[
        ["Plan a three-day trip to Washington DC for Cherry Blossom Festival."],
        [
            "Compose a short, joyful musical piece for kids celebrating spring sunshine and blossom."
        ],
        ["Can you explain briefly to me what is the Python programming language?"],
        ["Explain the plot of Cinderella in a sentence."],
        ["How many hours does it take a man to eat a Helicopter?"],
        ["Write a 100-word article on 'Benefits of Open-Source in AI research'."],
    ],
    cache_examples=False,
    title="Qwen3-30B-A3B",
    description=description,
)


if __name__ == "__main__":
    demo.queue(
        max_size=int(os.getenv("QUEUE")),
        default_concurrency_limit=int(os.getenv("CONCURRENCY_LIMIT")),
    ).launch()