import os
import torch
import spaces
import psycopg2
import gradio as gr
from threading import Thread
from collections.abc import Iterator
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gc

# Constants
MAX_MAX_NEW_TOKENS = 4096  # upper bound for the "Max New Tokens" slider
MAX_INPUT_TOKEN_LENGTH = 4096  # prompts longer than this are trimmed (tail kept)
DEFAULT_MAX_NEW_TOKENS = 2048  # default value for the "Max New Tokens" slider
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # Hugging Face token; empty string -> anonymous access

# Language lists
# Target languages offered in the dropdown of the IndicTrans tab.
INDIC_LANGUAGES = [
    "Hindi", "Bengali", "Telugu", "Marathi", "Tamil", "Urdu", "Gujarati", 
    "Kannada", "Odia", "Malayalam", "Punjabi", "Assamese", "Maithili", 
    "Santali", "Kashmiri", "Nepali", "Sindhi", "Konkani", "Dogri", 
    "Manipuri", "Bodo", "English", "Sanskrit"
]

# Both tabs currently expose the same language list.
SARVAM_LANGUAGES = INDIC_LANGUAGES

# Model configurations with optimizations
# bfloat16 on GPU for memory/speed; full float32 on CPU (bf16 CPU support varies).
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
DEVICE_MAP = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load both translation models eagerly at import time so the first request
# is not blocked by a multi-GB download/initialization.
indictrans_model = AutoModelForCausalLM.from_pretrained(
    "ai4bharat/IndicTrans3-beta",
    torch_dtype=TORCH_DTYPE,
    device_map=DEVICE_MAP,
    token=HF_TOKEN,
    low_cpu_mem_usage=True,  # stream weights instead of materializing a full fp32 copy in RAM
    trust_remote_code=True
)

sarvam_model = AutoModelForCausalLM.from_pretrained(
    "sarvamai/sarvam-translate",
    torch_dtype=TORCH_DTYPE,
    device_map=DEVICE_MAP,
    token=HF_TOKEN,
    low_cpu_mem_usage=True,
    trust_remote_code=True
)

# NOTE(review): a single tokenizer (IndicTrans3) is shared by BOTH models,
# including sarvam-translate — confirm the two models actually share a vocab.
tokenizer = AutoTokenizer.from_pretrained(
    "ai4bharat/IndicTrans3-beta",
    trust_remote_code=True
)


def format_message_for_translation(message, target_lang):
    """Wrap *message* in the instruction prompt the translation models expect."""
    prompt = f"Translate the following text to {target_lang}: {message}"
    return prompt

def store_feedback(rating, feedback_text, chat_history, tgt_lang, model_type):
    """Validate and persist user feedback into the `feedback` table.

    Shows a gr.Warning and returns None when a required field is missing;
    raises gr.Error (rendered by Gradio as an error toast) if the insert fails.

    Args:
        rating: Selected rating ("1".."5") from the radio group; stored as int.
        feedback_text: Free-text feedback from the user.
        chat_history: Chatbot history; used only to verify a translation exists.
        tgt_lang: Target language of the rated translation.
        model_type: "indictrans" or "sarvam".
    """
    # Input validation: warn and bail out rather than writing incomplete rows.
    if not rating:
        gr.Warning("Please select a rating before submitting feedback.", duration=5)
        return None

    if not feedback_text or feedback_text.strip() == "":
        gr.Warning("Please provide some feedback before submitting.", duration=5)
        return None

    if not chat_history:
        gr.Warning("Please provide the input text before submitting feedback.", duration=5)
        return None

    if len(chat_history[0]) < 2:
        gr.Warning("Please translate the input text before submitting feedback.", duration=5)
        return None

    conn = None
    try:
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        insert_query = """
        INSERT INTO feedback 
        (tgt_lang, rating, feedback_txt, chat_history, model_type)
        VALUES (%s, %s, %s, %s, %s)
        """
        # `with conn` commits on success and rolls back on error; the cursor
        # context manager closes the cursor either way.
        with conn, conn.cursor() as cursor:
            cursor.execute(insert_query, (tgt_lang, int(rating), feedback_text, chat_history, model_type))
        gr.Info("Thank you for your feedback! 🙏", duration=5)

    except Exception as e:
        print(f"Database error: {e}")
        # BUG FIX: gr.Error is an exception class — instantiating it without
        # `raise` displayed nothing to the user.
        raise gr.Error("An error occurred while storing feedback. Please try again later.", duration=5)
    finally:
        # BUG FIX: the connection previously leaked whenever any DB call raised.
        if conn is not None:
            conn.close()

def store_output(tgt_lang, input_text, output_text, model_type):
    """Best-effort logging of a finished translation into the `translation` table.

    DB errors are printed and swallowed on purpose so that logging failures
    never break the user-facing translation flow.

    Args:
        tgt_lang: Target language of the translation.
        input_text: Original source text.
        output_text: Generated translation.
        model_type: "indictrans" or "sarvam".
    """
    conn = None
    try:
        conn = psycopg2.connect(
            host=os.getenv("DB_HOST"),
            database=os.getenv("DB_NAME"),
            user=os.getenv("DB_USER"),
            password=os.getenv("DB_PASSWORD"),
            port=os.getenv("DB_PORT"),
        )
        insert_query = """
        INSERT INTO translation
        (input_txt, output_txt, tgt_lang, model_type)
        VALUES (%s, %s, %s, %s)
        """
        # `with conn` commits on success / rolls back on failure; the cursor
        # context manager closes the cursor in both cases.
        with conn, conn.cursor() as cursor:
            cursor.execute(insert_query, (input_text, output_text, tgt_lang, model_type))
    except Exception as e:
        print(f"Database error: {e}")
    finally:
        # BUG FIX: previously the connection leaked when the insert raised.
        if conn is not None:
            conn.close()

@spaces.GPU
def translate_message(
    message: str,
    chat_history: list[dict],
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    model_type: str = "indictrans"
) -> Iterator[str]:
    """Stream a translation of *message* into *target_language*.

    Yields progressively longer partial translations as tokens arrive from the
    model's streamer; on success the final text is persisted via store_output.

    Args:
        message: Source text to translate.
        chat_history: Prior chat turns (not used to build the prompt).
        target_language: Language name inserted into the instruction prompt.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Standard Hugging Face sampling/generation parameters.
        model_type: "indictrans" or "sarvam" — selects the preloaded model.
    """
    # BUG FIX: `model` was left unbound for an unknown model_type, raising
    # UnboundLocalError instead of yielding a readable error message.
    if model_type == "indictrans":
        model = indictrans_model
    elif model_type == "sarvam":
        model = sarvam_model
    else:
        model = None

    if model is None or tokenizer is None:
        yield "Error: Model failed to load. Please try again."
        return

    conversation = []
    translation_request = format_message_for_translation(message, target_language)
    conversation.append({"role": "user", "content": translation_request})

    try:
        input_ids = tokenizer.apply_chat_template(
            conversation, return_tensors="pt", add_generation_prompt=True
        )

        # Keep only the most recent tokens when the prompt exceeds the limit.
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

        input_ids = input_ids.to(model.device)

        streamer = TextIteratorStreamer(
            tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
        )

        generate_kwargs = {
            "input_ids": input_ids,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "num_beams": 1,
            "repetition_penalty": repetition_penalty,
            "use_cache": True,  # Enable KV cache
        }

        # Generation runs on a background thread; the streamer feeds this loop.
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()

        outputs = []
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)

        # BUG FIX: join the worker so generation has fully finished (and any
        # late tokens flushed) before the result is persisted.
        t.join()

        store_output(target_language, message, "".join(outputs), model_type)

    except Exception as e:
        yield f"Translation error: {str(e)}"
    finally:
        # BUG FIX: cleanup previously ran only on the success path, so a
        # failed generation kept its GPU cache alive between requests.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

# Enhanced CSS with beautiful styling
# Dark theme: #1a1a1a page background, #2a2a2a panels, #333333 inputs,
# indigo (#3333a0 / #6666ff) accents; Inter as the global font.
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');

* {
    font-family: 'Inter', sans-serif;
    box-sizing: border-box;
}

.gradio-container {
    background: #1a1a1a !important;
    color: #e0e0e0;
    min-height: 100vh;
}

.main-container {
    background: #2a2a2a;
    border-radius: 12px;
    padding: 1.5rem;
    margin: 1rem;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}

.title-container {
    text-align: center;
    margin-bottom: 1.5rem;
    padding: 1rem;
    color: #a0a0ff;
}

.model-tab {
    background: #3333a0;
    border: none;
    border-radius: 8px;
    color: #ffffff;
    font-weight: 500;
    padding: 0.75rem 1.5rem;
    transition: all 0.2s ease;
}

.model-tab:hover {
    background: #4444b0;
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}

.language-dropdown {
    background: #333333;
    border: 1px solid #444444;
    border-radius: 8px;
    padding: 0.5rem;
    font-size: 14px;
    color: #e0e0e0;
    transition: all 0.2s ease;
}

.language-dropdown:focus {
    border-color: #6666ff;
    box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2);
}

.chat-container {
    background: #222222;
    border-radius: 8px;
    padding: 1rem;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
    margin: 1rem 0;
}

.message-input {
    background: #333333;
    border: 1px solid #444444;
    border-radius: 8px;
    padding: 0.75rem;
    font-size: 14px;
    color: #e0e0e0;
    transition: all 0.2s ease;
}

.message-input:focus {
    border-color: #6666ff;
    box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2);
}

.translate-btn {
    background: #3333a0;
    border: none;
    border-radius: 8px;
    color: #ffffff;
    font-weight: 500;
    padding: 0.75rem 1.5rem;
    font-size: 14px;
    cursor: pointer;
    transition: all 0.2s ease;
}

.translate-btn:hover {
    background: #4444b0;
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}

.examples-container {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
}

.feedback-section {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
    border: none;
}

.advanced-options {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
}

.slider-container .gr-slider {
    background: #444444;
    color: #e0e0e0;
}

.rating-container {
    display: flex;
    gap: 0.5rem;
    justify-content: center;
    margin: 0.5rem 0;
}

.feedback-btn {
    background: #3333a0;
    border: none;
    border-radius: 8px;
    color: #ffffff;
    font-weight: 500;
    padding: 0.5rem 1rem;
    cursor: pointer;
    transition: all 0.2s ease;
}

.feedback-btn:hover {
    background: #4444b0;
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}

.stats-card {
    background: #333333;
    border-radius: 8px;
    padding: 0.75rem;
    text-align: center;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
    margin: 0.5rem;
    color: #e0e0e0;
}

.model-info {
    background: #3333a0;
    color: #ffffff;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
}

.animate-pulse {
    animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}

@keyframes pulse {
    0%, 100% {
        opacity: 1;
    }
    50% {
        opacity: 0.5;
    }
}

.loading-spinner {
    border: 3px solid #444444;
    border-top: 3px solid #6666ff;
    border-radius: 50%;
    width: 30px;
    height: 30px;
    animation: spin 1.5s linear infinite;
    margin: 0 auto;
}

@keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
}
"""

# Model descriptions
# HTML blurbs rendered (via gr.Markdown) at the top of each model tab;
# styled by the `.model-info` CSS class above.
INDICTRANS_DESCRIPTION = """
<div class="model-info">
<h3>🌟 IndicTrans3-Beta</h3>
<p><strong>Latest SOTA translation model from AI4Bharat</strong></p>
<ul>
<li>✅ Supports <strong>22 Indic languages</strong></li>
<li>✅ Document-level machine translation</li>
<li>✅ Optimized for real-world applications</li>
<li>✅ Enhanced with KV caching for faster inference</li>
</ul>
</div>
"""

SARVAM_DESCRIPTION = """
<div class="model-info">
<h3>🚀 Sarvam Translate</h3>
<p><strong>Advanced multilingual translation model</strong></p>
<ul>
<li>✅ Supports <strong>22 Indic languages</strong></li>
<li>✅ High-quality translations</li>
<li>✅ Document-level machine translation</li>
<li>✅ Optimized for real-world applications</li>
<li>✅ Optimized for production use</li>
<li>✅ Enhanced with KV caching for faster inference</li>
</ul>
</div>
"""

def create_chatbot_interface(model_type, languages, description):
    """Build one translation tab: dropdown, chatbot, examples, feedback, settings.

    Must be called inside a gr.Blocks context. Gradio lays components out in
    the order they are instantiated, so statement order here IS the layout.

    Args:
        model_type: "indictrans" or "sarvam" — selects which example texts show.
        languages: Choices for the target-language dropdown (first is default).
        description: HTML blurb rendered at the top of the tab.

    Returns:
        A 12-tuple of components, in the exact order the module-level event
        wiring unpacks them.
    """
    with gr.Column(elem_classes="main-container"):
        gr.Markdown(description)

        target_language = gr.Dropdown(
            languages,
            value=languages[0],
            label="🌍 Select Target Language",
            elem_classes="language-dropdown",
        )

        chatbot = gr.Chatbot(
            height=500,
            elem_classes="chat-container",
            show_copy_button=True,
            avatar_images=["avatars/user_logo.png", "avatars/ai4bharat_logo.png"],
            bubble_full_width=False,
            show_label=False
        )

        # Input row: text box plus a translate button side by side.
        with gr.Row():
            msg = gr.Textbox(
                placeholder="✍️ Enter text to translate...",
                show_label=False,
                container=False,
                scale=9,
                elem_classes="message-input",
            )
            submit_btn = gr.Button(
                "🔄 Translate", 
                scale=1,
                elem_classes="translate-btn"
            )

        # Examples section
        # The IndicTrans tab shows longer, India-centric passages; the Sarvam
        # tab shows short generic sentences.
        if model_type == "indictrans":
            examples_data = [
                "The Taj Mahal, an architectural marvel of white marble, stands majestically along the banks of the Yamuna River in Agra, India.",
                "Kumbh Mela, the world's largest spiritual gathering, is a significant Hindu festival held at four sacred riverbanks.",
                "India's classical dance forms, such as Bharatanatyam, Kathak, Odissi, are deeply rooted in tradition and storytelling.",
                "Ayurveda, India's ancient medical system, emphasizes a holistic approach to health by balancing mind, body, and spirit.",
                "Diwali, the festival of lights, symbolizes the victory of light over darkness and good over evil."
            ]
        else:
            examples_data = [
                "Hello, how are you today?",
                "I love learning new languages and cultures.",
                "Technology is transforming the way we communicate.",
                "The weather is beautiful today.",
                "Thank you for your help and support."
            ]

        with gr.Accordion("📚 Example Texts", open=False, elem_classes="examples-container"):
            gr.Examples(
                examples=examples_data,
                inputs=msg,
                label="Click on any example to try:"
            )

        # Feedback section
        # Rating + free-text feedback, wired at module level to store_feedback.
        with gr.Accordion("💭 Provide Feedback", open=False, elem_classes="feedback-section"):
            gr.Markdown("### 📝 Rate Translation & Share Feedback")
            gr.Markdown("Help us improve translation quality with your valuable feedback!")

            with gr.Row():
                rating = gr.Radio(
                    ["1", "2", "3", "4", "5"],
                    label="🏆 Translation Quality Rating",
                    value=None
                )

            feedback_text = gr.Textbox(
                placeholder="💬 Share your thoughts about the translation quality, accuracy, or suggestions for improvement...",
                label="📝 Your Feedback",
                lines=3,
            )

            feedback_submit = gr.Button(
                "📤 Submit Feedback",
                elem_classes="feedback-btn"
            )

        # Advanced options
        # Generation-parameter sliders fed into translate_message.
        with gr.Accordion("⚙️ Advanced Settings", open=False, elem_classes="advanced-options"):
            gr.Markdown("### 🔧 Fine-tune Translation Parameters")

            with gr.Row():
                max_new_tokens = gr.Slider(
                    label="📏 Max New Tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    elem_classes="slider-container"
                )
                # NOTE(review): slider defaults (temperature 0.1, repetition
                # penalty 1.0) differ from translate_message's own defaults
                # (0.6 / 1.2) — the slider values win since they are always
                # passed explicitly; confirm this is intentional.
                temperature = gr.Slider(
                    label="🌡️ Temperature",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.1,
                    elem_classes="slider-container"
                )

            with gr.Row():
                top_p = gr.Slider(
                    label="🎯 Top-p (Nucleus Sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                    elem_classes="slider-container"
                )
                top_k = gr.Slider(
                    label="🔝 Top-k",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    elem_classes="slider-container"
                )

            repetition_penalty = gr.Slider(
                label="🔄 Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
                elem_classes="slider-container"
            )

    return (chatbot, msg, submit_btn, target_language, rating, feedback_text, 
            feedback_submit, max_new_tokens, temperature, top_p, top_k, repetition_penalty)

def user(user_message, history, target_lang):
    """Append the pending user turn to the chat and clear the textbox.

    `target_lang` is accepted only to match the Gradio input wiring; it is
    not used here. Returns ("", new_history) — the empty string resets the
    message box, and the new history gains a [message, None] pair whose
    answer slot is filled in later by `bot`.
    """
    updated_history = [*history, [user_message, None]]
    return "", updated_history

def bot(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty, model_type):
    """Stream the translation of the newest user turn into the chat history.

    Repeatedly yields `history` with the answer slot of the last turn grown
    by the latest partial output from translate_message, so the Gradio
    chatbot updates live while tokens arrive.
    """
    latest_message = history[-1][0]
    history[-1][1] = ""

    stream = translate_message(
        latest_message,
        history[:-1],
        target_lang,
        max_tokens,
        temp,
        top_p_val,
        top_k_val,
        rep_penalty,
        model_type,
    )
    for partial_text in stream:
        history[-1][1] = partial_text
        yield history

# Main Gradio interface
# Page layout: title banner, four stat cards, one tab per model (each built
# by create_chatbot_interface), then the event wiring and a footer.
with gr.Blocks(css=css, title="🌍 Advanced Multilingual Translation Hub", theme=gr.themes.Soft()) as demo:

    gr.Markdown(
        """
        <div class="title-container">
        <h1>🌍 Advanced Multilingual Translation Hub</h1>
        <p style="font-size: 18px; margin-top: 10px;">
        Experience state-of-the-art translation with multiple AI models
        </p>
        </div>
        """,
        elem_classes="title-container"
    )

    # Statistics cards
    with gr.Row():
        gr.Markdown(
            '<div class="stats-card"><h3>🎯</h3><p><strong>22+</strong><br>Languages</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>🚀</h3><p><strong>2</strong><br>AI Models</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>⚡</h3><p><strong>Optimized</strong><br>Performance</p></div>',
            elem_classes="stats-card"
        )
        gr.Markdown(
            '<div class="stats-card"><h3>🔒</h3><p><strong>Secure</strong><br>Processing</p></div>',
            elem_classes="stats-card"
        )

    # One tab per backing model; both tabs share the same builder function.
    with gr.Tabs(elem_classes="model-tab") as tabs:
        with gr.TabItem("🇮🇳 IndicTrans3-Beta", elem_id="indictrans-tab"):
            indictrans_components = create_chatbot_interface("indictrans", INDIC_LANGUAGES, INDICTRANS_DESCRIPTION)

        with gr.TabItem("🌐 Sarvam Translate", elem_id="sarvam-tab"):
            sarvam_components = create_chatbot_interface("sarvam", SARVAM_LANGUAGES, SARVAM_DESCRIPTION)

    # Event handlers for IndicTrans
    # Unpack order must match create_chatbot_interface's return tuple.
    (indictrans_chatbot, indictrans_msg, indictrans_submit, indictrans_lang, 
     indictrans_rating, indictrans_feedback, indictrans_feedback_submit, 
     indictrans_max_tokens, indictrans_temp, indictrans_top_p, 
     indictrans_top_k, indictrans_rep_penalty) = indictrans_components

    # Enter key: append the user turn (no queue), then stream the bot reply.
    # The lambda appends the fixed model_type as bot's final argument.
    indictrans_msg.submit(
        user, [indictrans_msg, indictrans_chatbot, indictrans_lang], 
        [indictrans_msg, indictrans_chatbot], queue=False
    ).then(
        lambda *args: bot(*args, "indictrans"),
        [indictrans_chatbot, indictrans_lang, indictrans_max_tokens, 
         indictrans_temp, indictrans_top_p, indictrans_top_k, indictrans_rep_penalty],
        indictrans_chatbot,
    )

    # Translate button: same chain as pressing Enter.
    indictrans_submit.click(
        user, [indictrans_msg, indictrans_chatbot, indictrans_lang], 
        [indictrans_msg, indictrans_chatbot], queue=False
    ).then(
        lambda *args: bot(*args, "indictrans"),
        [indictrans_chatbot, indictrans_lang, indictrans_max_tokens, 
         indictrans_temp, indictrans_top_p, indictrans_top_k, indictrans_rep_penalty],
        indictrans_chatbot,
    )

    indictrans_feedback_submit.click(
        lambda *args: store_feedback(*args, "indictrans"),
        inputs=[indictrans_rating, indictrans_feedback, indictrans_chatbot, indictrans_lang],
    )

    # Event handlers for Sarvam
    # Mirrors the IndicTrans wiring with model_type "sarvam".
    (sarvam_chatbot, sarvam_msg, sarvam_submit, sarvam_lang, 
     sarvam_rating, sarvam_feedback, sarvam_feedback_submit, 
     sarvam_max_tokens, sarvam_temp, sarvam_top_p, 
     sarvam_top_k, sarvam_rep_penalty) = sarvam_components

    sarvam_msg.submit(
        user, [sarvam_msg, sarvam_chatbot, sarvam_lang], 
        [sarvam_msg, sarvam_chatbot], queue=False
    ).then(
        lambda *args: bot(*args, "sarvam"),
        [sarvam_chatbot, sarvam_lang, sarvam_max_tokens, 
         sarvam_temp, sarvam_top_p, sarvam_top_k, sarvam_rep_penalty],
        sarvam_chatbot,
    )

    sarvam_submit.click(
        user, [sarvam_msg, sarvam_chatbot, sarvam_lang], 
        [sarvam_msg, sarvam_chatbot], queue=False
    ).then(
        lambda *args: bot(*args, "sarvam"),
        [sarvam_chatbot, sarvam_lang, sarvam_max_tokens, 
         sarvam_temp, sarvam_top_p, sarvam_top_k, sarvam_rep_penalty],
        sarvam_chatbot,
    )

    sarvam_feedback_submit.click(
        lambda *args: store_feedback(*args, "sarvam"),
        inputs=[sarvam_rating, sarvam_feedback, sarvam_chatbot, sarvam_lang],
    )

    # Footer
    gr.Markdown(
        """
        <div style="text-align: center; margin-top: 2rem; padding: 1rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; color: white;">
        <p>🚀 <strong>Powered by AI4Bharat & Sarvam AI</strong> | 
        Built with ❤️ using Gradio | 
        🔧 <strong>Optimized with KV Caching & Advanced Memory Management</strong></p>
        </div>
        """
    )

if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the Hugging Face Spaces default);
    # show_error surfaces handler tracebacks in the UI.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )