"""Gradio translation hub serving IndicTrans3-Beta and Sarvam Translate.

Each model gets its own tab with a streaming chat-style translator, a feedback
form, and sampling controls. Translations and feedback are written to a
PostgreSQL database configured through the DB_* environment variables.
"""

import gc
import os
from collections.abc import Iterator
from contextlib import closing
from threading import Thread

import gradio as gr
import psycopg2
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Constants
MAX_MAX_NEW_TOKENS = 4096
MAX_INPUT_TOKEN_LENGTH = 4096
DEFAULT_MAX_NEW_TOKENS = 2048
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Language lists
INDIC_LANGUAGES = [
    "Hindi", "Bengali", "Telugu", "Marathi", "Tamil", "Urdu", "Gujarati",
    "Kannada", "Odia", "Malayalam", "Punjabi", "Assamese", "Maithili",
    "Santali", "Kashmiri", "Nepali", "Sindhi", "Konkani", "Dogri",
    "Manipuri", "Bodo", "English", "Sanskrit",
]
SARVAM_LANGUAGES = INDIC_LANGUAGES

# Model configurations with optimizations
TORCH_DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
DEVICE_MAP = "cuda:0" if torch.cuda.is_available() else "cpu"

INDICTRANS_MODEL_ID = "ai4bharat/IndicTrans3-beta"
SARVAM_MODEL_ID = "sarvamai/sarvam-translate"


def _load_model(repo_id):
    """Load a causal-LM checkpoint with the shared dtype/device settings."""
    return AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=TORCH_DTYPE,
        device_map=DEVICE_MAP,
        token=HF_TOKEN,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )


def _load_tokenizer(repo_id):
    """Load the tokenizer that ships with *repo_id*."""
    return AutoTokenizer.from_pretrained(repo_id, token=HF_TOKEN, trust_remote_code=True)


indictrans_model = _load_model(INDICTRANS_MODEL_ID)
sarvam_model = _load_model(SARVAM_MODEL_ID)
# `tokenizer` keeps its historical name; it belongs to IndicTrans3.
tokenizer = _load_tokenizer(INDICTRANS_MODEL_ID)
# Each model must be driven by its own tokenizer (vocab + chat template).
sarvam_tokenizer = _load_tokenizer(SARVAM_MODEL_ID)

# model_type -> (model, tokenizer) used by translate_message.
MODELS = {
    "indictrans": (indictrans_model, tokenizer),
    "sarvam": (sarvam_model, sarvam_tokenizer),
}


def format_message_for_translation(message, target_lang):
    """Wrap *message* in the instruction prompt sent to the model."""
    return f"Translate the following text to {target_lang}: {message}"


def _db_connect():
    """Open a new psycopg2 connection from the DB_* environment variables."""
    return psycopg2.connect(
        host=os.getenv("DB_HOST"),
        database=os.getenv("DB_NAME"),
        user=os.getenv("DB_USER"),
        password=os.getenv("DB_PASSWORD"),
        port=os.getenv("DB_PORT"),
    )


def _db_insert(query, params):
    """Execute one parameterized statement and commit it.

    The connection is closed even when ``execute`` or ``commit`` raises.
    psycopg2's own ``with conn:`` only ends the transaction and does not close.
    """
    with closing(_db_connect()) as conn:
        with conn.cursor() as cursor:
            cursor.execute(query, params)
        conn.commit()


def store_feedback(rating, feedback_text, chat_history, tgt_lang, model_type):
    """Validate a feedback submission and insert it into the ``feedback`` table.

    When an input is missing, shows a Gradio warning toast and returns None.

    Args:
        rating: Selected radio value ("1".."5") or None.
        feedback_text: Free-text feedback.
        chat_history: Chatbot value, a list of ``[user_message, response]`` turns.
        tgt_lang: Target language selected in the tab.
        model_type: "indictrans" or "sarvam".

    Raises:
        gr.Error: If the database write fails, so the UI shows an error toast.
    """
    if not rating:
        gr.Warning("Please select a rating before submitting feedback.", duration=5)
        return None
    if not feedback_text or not feedback_text.strip():
        gr.Warning("Please provide some feedback before submitting.", duration=5)
        return None
    if not chat_history:
        gr.Warning("Please provide the input text before submitting feedback.", duration=5)
        return None
    # Every turn has two slots ([user_message, response]), so check that the
    # newest turn actually has a translation rather than checking its length.
    last_turn = chat_history[-1]
    if len(last_turn) < 2 or not last_turn[1]:
        gr.Warning("Please translate the input text before submitting feedback.", duration=5)
        return None

    try:
        _db_insert(
            """
            INSERT INTO feedback (tgt_lang, rating, feedback_txt, chat_history, model_type)
            VALUES (%s, %s, %s, %s, %s)
            """,
            (tgt_lang, int(rating), feedback_text, chat_history, model_type),
        )
    except Exception as e:
        print(f"Database error: {e}")
        # gr.Error only reaches the UI when raised; constructing it does nothing.
        raise gr.Error(
            "An error occurred while storing feedback. Please try again later.", duration=5
        ) from e

    gr.Info("Thank you for your feedback! 🙏", duration=5)
    return None


def store_output(tgt_lang, input_text, output_text, model_type):
    """Log one translation to the ``translation`` table (best effort).

    Database errors are printed and swallowed so they never break the UI stream.
    """
    try:
        _db_insert(
            """
            INSERT INTO translation (input_txt, output_txt, tgt_lang, model_type)
            VALUES (%s, %s, %s, %s)
            """,
            (input_text, output_text, tgt_lang, model_type),
        )
    except Exception as e:
        print(f"Database error: {e}")


@spaces.GPU
def translate_message(
    message: str,
    chat_history: list,
    target_language: str = "Hindi",
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
    model_type: str = "indictrans",
) -> Iterator[str]:
    """Stream a translation of *message* into *target_language*.

    Yields the cumulative decoded text after every streamed chunk, so the last
    value yielded is the full translation. A successful translation is logged
    via ``store_output``. On failure, an error string is yielded instead.

    Args:
        message: Source text to translate.
        chat_history: Earlier turns. Unused, because each request is
            translated independently.
        target_language: Language name inserted into the prompt.
        max_new_tokens, temperature, top_p, top_k, repetition_penalty:
            Sampling parameters forwarded to ``model.generate``.
        model_type: Key into ``MODELS`` ("indictrans" or "sarvam").
    """
    if model_type not in MODELS:
        yield f"Error: Unknown model type '{model_type}'."
        return
    model, model_tokenizer = MODELS[model_type]
    if model is None or model_tokenizer is None:
        yield "Error: Model failed to load. Please try again."
        return

    conversation = [
        {"role": "user", "content": format_message_for_translation(message, target_language)}
    ]
    outputs: list[str] = []
    try:
        input_ids = model_tokenizer.apply_chat_template(
            conversation, return_tensors="pt", add_generation_prompt=True
        )
        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
            # NOTE(review): keeping only the last tokens drops the start of the
            # prompt (chat-template header and the "Translate ... to X:"
            # instruction). Consider rejecting over-long input instead.
            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
            gr.Warning(f"Trimmed input as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
        input_ids = input_ids.to(model.device)

        streamer = TextIteratorStreamer(
            model_tokenizer, timeout=240.0, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = {
            "input_ids": input_ids,
            # Single unpadded sequence: every position is a real token.
            "attention_mask": torch.ones_like(input_ids),
            "streamer": streamer,
            # Sliders may deliver floats; generate() needs ints here.
            "max_new_tokens": int(max_new_tokens),
            "do_sample": True,
            "top_p": top_p,
            "top_k": int(top_k),
            "temperature": temperature,
            "num_beams": 1,
            "repetition_penalty": repetition_penalty,
            "use_cache": True,  # Enable KV cache
        }
        # generate() blocks until done, so run it in a worker thread and
        # consume decoded text from the streamer as it is produced.
        generation_thread = Thread(target=model.generate, kwargs=generate_kwargs)
        generation_thread.start()
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
        generation_thread.join()
    except Exception as e:
        yield f"Translation error: {str(e)}"
        return
    finally:
        # Release cached GPU memory on success, failure, or early close.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

    store_output(target_language, message, "".join(outputs), model_type)


# Enhanced CSS with beautiful styling
css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap');

* {
    font-family: 'Inter', sans-serif;
    box-sizing: border-box;
}

.gradio-container {
    background: #1a1a1a !important;
    color: #e0e0e0;
    min-height: 100vh;
}

.main-container {
    background: #2a2a2a;
    border-radius: 12px;
    padding: 1.5rem;
    margin: 1rem;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
}

.title-container {
    text-align: center;
    margin-bottom: 1.5rem;
    padding: 1rem;
    color: #a0a0ff;
}

.model-tab {
    background: #3333a0;
    border: none;
    border-radius: 8px;
    color: #ffffff;
    font-weight: 500;
    padding: 0.75rem 1.5rem;
    transition: all 0.2s ease;
}

.model-tab:hover {
    background: #4444b0;
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}

.language-dropdown {
    background: #333333;
    border: 1px solid #444444;
    border-radius: 8px;
    padding: 0.5rem;
    font-size: 14px;
    color: #e0e0e0;
    transition: all 0.2s ease;
}

.language-dropdown:focus {
    border-color: #6666ff;
    box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2);
}

.chat-container {
    background: #222222;
    border-radius: 8px;
    padding: 1rem;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
    margin: 1rem 0;
}

.message-input {
    background: #333333;
    border: 1px solid #444444;
    border-radius: 8px;
    padding: 0.75rem;
    font-size: 14px;
    color: #e0e0e0;
    transition: all 0.2s ease;
}

.message-input:focus {
    border-color: #6666ff;
    box-shadow: 0 0 0 2px rgba(102, 102, 255, 0.2);
}

.translate-btn {
    background: #3333a0;
    border: none;
    border-radius: 8px;
    color: #ffffff;
    font-weight: 500;
    padding: 0.75rem 1.5rem;
    font-size: 14px;
    cursor: pointer;
    transition: all 0.2s ease;
}

.translate-btn:hover {
    background: #4444b0;
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}

.examples-container {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
}

.feedback-section {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
    border: none;
}

.advanced-options {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
}

.slider-container .gr-slider {
    background: #444444;
    color: #e0e0e0;
}

.rating-container {
    display: flex;
    gap: 0.5rem;
    justify-content: center;
    margin: 0.5rem 0;
}

.feedback-btn {
    background: #3333a0;
    border: none;
    border-radius: 8px;
    color: #ffffff;
    font-weight: 500;
    padding: 0.5rem 1rem;
    cursor: pointer;
    transition: all 0.2s ease;
}

.feedback-btn:hover {
    background: #4444b0;
    transform: translateY(-1px);
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.4);
}

.stats-card {
    background: #333333;
    border-radius: 8px;
    padding: 0.75rem;
    text-align: center;
    box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
    margin: 0.5rem;
    color: #e0e0e0;
}

.model-info {
    background: #3333a0;
    color: #ffffff;
    border-radius: 8px;
    padding: 1rem;
    margin: 1rem 0;
}

.animate-pulse {
    animation: pulse 2s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}

@keyframes pulse {
    0%, 100% { opacity: 1; }
    50% { opacity: 0.5; }
}

.loading-spinner {
    border: 3px solid #444444;
    border-top: 3px solid #6666ff;
    border-radius: 50%;
    width: 30px;
    height: 30px;
    animation: spin 1.5s linear infinite;
    margin: 0 auto;
}

@keyframes spin {
    0% { transform: rotate(0deg); }
    100% { transform: rotate(360deg); }
}
"""

# Model descriptions
INDICTRANS_DESCRIPTION = """
<div class="model-info">
    <h3>🌟 IndicTrans3-Beta</h3>
    <p><strong>Latest SOTA translation model from AI4Bharat</strong></p>
    <ul>
        <li>✅ Supports <strong>22 Indic languages</strong></li>
        <li>✅ Document-level machine translation</li>
        <li>✅ Optimized for real-world applications</li>
        <li>✅ Enhanced with KV caching for faster inference</li>
    </ul>
</div>
"""

SARVAM_DESCRIPTION = """
<div class="model-info">
    <h3>🚀 Sarvam Translate</h3>
    <p><strong>Advanced multilingual translation model</strong></p>
    <ul>
        <li>✅ Supports <strong>22 Indic languages</strong></li>
        <li>✅ High-quality translations</li>
        <li>✅ Document-level machine translation</li>
        <li>✅ Optimized for real-world applications</li>
        <li>✅ Optimized for production use</li>
        <li>✅ Enhanced with KV caching for faster inference</li>
    </ul>
</div>
"""


def create_chatbot_interface(model_type, languages, description):
    """Build one model tab's widgets and return them for event wiring.

    Intended to be called inside a ``gr.Blocks`` layout context, as the main
    UI below does.

    Args:
        model_type: "indictrans" selects the culture-themed examples; any
            other value selects the generic examples.
        languages: Choices for the target-language dropdown; the first entry
            is preselected.
        description: HTML/Markdown shown at the top of the tab.

    Returns:
        Tuple ``(chatbot, msg, submit_btn, target_language, rating,
        feedback_text, feedback_submit, max_new_tokens, temperature, top_p,
        top_k, repetition_penalty)``.
    """
    with gr.Column(elem_classes="main-container"):
        gr.Markdown(description)

        target_language = gr.Dropdown(
            languages,
            value=languages[0],
            label="🌍 Select Target Language",
            elem_classes="language-dropdown",
        )

        chatbot = gr.Chatbot(
            height=500,
            elem_classes="chat-container",
            show_copy_button=True,
            avatar_images=["avatars/user_logo.png", "avatars/ai4bharat_logo.png"],
            bubble_full_width=False,
            show_label=False,
        )

        with gr.Row():
            msg = gr.Textbox(
                placeholder="✍️ Enter text to translate...",
                show_label=False,
                container=False,
                scale=9,
                elem_classes="message-input",
            )
            submit_btn = gr.Button("🔄 Translate", scale=1, elem_classes="translate-btn")

        # Examples section
        if model_type == "indictrans":
            examples_data = [
                "The Taj Mahal, an architectural marvel of white marble, stands majestically along the banks of the Yamuna River in Agra, India.",
                "Kumbh Mela, the world's largest spiritual gathering, is a significant Hindu festival held at four sacred riverbanks.",
                "India's classical dance forms, such as Bharatanatyam, Kathak, Odissi, are deeply rooted in tradition and storytelling.",
                "Ayurveda, India's ancient medical system, emphasizes a holistic approach to health by balancing mind, body, and spirit.",
                "Diwali, the festival of lights, symbolizes the victory of light over darkness and good over evil.",
            ]
        else:
            examples_data = [
                "Hello, how are you today?",
                "I love learning new languages and cultures.",
                "Technology is transforming the way we communicate.",
                "The weather is beautiful today.",
                "Thank you for your help and support.",
            ]

        with gr.Accordion("📚 Example Texts", open=False, elem_classes="examples-container"):
            gr.Examples(
                examples=examples_data,
                inputs=msg,
                label="Click on any example to try:",
            )

        # Feedback section
        with gr.Accordion("💭 Provide Feedback", open=False, elem_classes="feedback-section"):
            gr.Markdown("### 📝 Rate Translation & Share Feedback")
            gr.Markdown("Help us improve translation quality with your valuable feedback!")
            with gr.Row():
                rating = gr.Radio(
                    ["1", "2", "3", "4", "5"],
                    label="🏆 Translation Quality Rating",
                    value=None,
                )
            feedback_text = gr.Textbox(
                placeholder="💬 Share your thoughts about the translation quality, accuracy, or suggestions for improvement...",
                label="📝 Your Feedback",
                lines=3,
            )
            feedback_submit = gr.Button("📤 Submit Feedback", elem_classes="feedback-btn")

        # Advanced options
        with gr.Accordion("⚙️ Advanced Settings", open=False, elem_classes="advanced-options"):
            gr.Markdown("### 🔧 Fine-tune Translation Parameters")
            with gr.Row():
                max_new_tokens = gr.Slider(
                    label="📏 Max New Tokens",
                    minimum=1,
                    maximum=MAX_MAX_NEW_TOKENS,
                    step=1,
                    value=DEFAULT_MAX_NEW_TOKENS,
                    elem_classes="slider-container",
                )
                temperature = gr.Slider(
                    label="🌡️ Temperature",
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    value=0.1,
                    elem_classes="slider-container",
                )
            with gr.Row():
                top_p = gr.Slider(
                    label="🎯 Top-p (Nucleus Sampling)",
                    minimum=0.05,
                    maximum=1.0,
                    step=0.05,
                    value=0.9,
                    elem_classes="slider-container",
                )
                top_k = gr.Slider(
                    label="🔝 Top-k",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    elem_classes="slider-container",
                )
            repetition_penalty = gr.Slider(
                label="🔄 Repetition Penalty",
                minimum=1.0,
                maximum=2.0,
                step=0.05,
                value=1.0,
                elem_classes="slider-container",
            )

    return (
        chatbot, msg, submit_btn, target_language, rating, feedback_text,
        feedback_submit, max_new_tokens, temperature, top_p, top_k, repetition_penalty,
    )


def user(user_message, history, target_lang):
    """Append *user_message* as a new pending turn and clear the textbox.

    Blank input is ignored, so ``bot`` finds no pending turn. *target_lang* is
    accepted to match the event wiring but is not used.

    Returns:
        ``("", updated_history)``; the empty string clears the textbox.
    """
    if not user_message or not user_message.strip():
        return "", history
    return "", (history or []) + [[user_message, None]]


def bot(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty, model_type):
    """Stream a translation into the response slot of the newest chat turn.

    Yields the updated history after every chunk. If there is no pending turn
    (empty history, or the newest turn already has a response), it yields the
    history unchanged and stops.
    """
    if not history or history[-1][1] is not None:
        yield history
        return

    user_message = history[-1][0]
    history[-1][1] = ""
    for chunk in translate_message(
        user_message, history[:-1], target_lang, max_tokens, temp,
        top_p_val, top_k_val, rep_penalty, model_type,
    ):
        history[-1][1] = chunk
        yield history


def _make_bot_handler(model_type):
    """Return a streaming event handler that runs ``bot`` for *model_type*.

    Gradio decides whether a handler streams by checking that it is a
    generator function. A ``lambda`` that merely returns ``bot(...)`` fails
    that check, so this closure delegates with ``yield from``.
    """
    def handler(history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty):
        yield from bot(
            history, target_lang, max_tokens, temp, top_p_val, top_k_val, rep_penalty, model_type
        )

    return handler


def _wire_events(components, model_type):
    """Attach submit, translate and feedback listeners to one model tab.

    Args:
        components: The tuple returned by ``create_chatbot_interface``.
        model_type: "indictrans" or "sarvam", forwarded to the handlers.
    """
    (chatbot, msg, submit_btn, target_language, rating, feedback_text,
     feedback_submit, max_new_tokens, temperature, top_p, top_k,
     repetition_penalty) = components

    respond = _make_bot_handler(model_type)
    generation_inputs = [
        chatbot, target_language, max_new_tokens, temperature, top_p, top_k, repetition_penalty,
    ]
    # Pressing Enter in the textbox and clicking Translate behave the same.
    for trigger in (msg.submit, submit_btn.click):
        trigger(
            user, [msg, chatbot, target_language], [msg, chatbot], queue=False
        ).then(respond, generation_inputs, chatbot)

    def submit_feedback(rating_value, feedback_value, chat_history, tgt_lang):
        return store_feedback(rating_value, feedback_value, chat_history, tgt_lang, model_type)

    feedback_submit.click(
        submit_feedback,
        inputs=[rating, feedback_text, chatbot, target_language],
    )


# Main Gradio interface
with gr.Blocks(css=css, title="🌍 Advanced Multilingual Translation Hub", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
<div class="title-container">
    <h1>🌍 Advanced Multilingual Translation Hub</h1>
    <p style="font-size: 18px; margin-top: 10px;">
        Experience state-of-the-art translation with multiple AI models
    </p>
</div>
""",
        elem_classes="title-container",
    )

    # Statistics cards
    with gr.Row():
        for icon, headline, caption in (
            ("🎯", "22+", "Languages"),
            ("🚀", "2", "AI Models"),
            ("⚡", "Optimized", "Performance"),
            ("🔒", "Secure", "Processing"),
        ):
            gr.Markdown(
                f'<div class="stats-card"><h3>{icon}</h3><p><strong>{headline}</strong><br>{caption}</p></div>',
                elem_classes="stats-card",
            )

    with gr.Tabs(elem_classes="model-tab") as tabs:
        with gr.TabItem("🇮🇳 IndicTrans3-Beta", elem_id="indictrans-tab"):
            indictrans_components = create_chatbot_interface(
                "indictrans", INDIC_LANGUAGES, INDICTRANS_DESCRIPTION
            )
        with gr.TabItem("🌐 Sarvam Translate", elem_id="sarvam-tab"):
            sarvam_components = create_chatbot_interface(
                "sarvam", SARVAM_LANGUAGES, SARVAM_DESCRIPTION
            )

    # Event handlers for each model tab
    _wire_events(indictrans_components, "indictrans")
    _wire_events(sarvam_components, "sarvam")

    # Footer
    gr.Markdown(
        """
<div style="text-align: center; margin-top: 2rem; padding: 1rem; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 15px; color: white;">
    <p>🚀 <strong>Powered by AI4Bharat & Sarvam AI</strong> | Built with ❤️ using Gradio | 🔧 <strong>Optimized with KV Caching & Advanced Memory Management</strong></p>
</div>
"""
    )


if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )