# --- START OF MODIFIED app.py --- try: import spaces print("'spaces' module imported successfully.") except ImportError: print("Warning: 'spaces' module not found. Using dummy decorator for local execution.") class DummySpaces: def GPU(self, *args, **kwargs): def decorator(func): print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.") return func return decorator spaces = DummySpaces() import gradio as gr from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import torch import os import math # --- Configuration --- MODEL_PATH = "gregniuki/mandarin_thai_ipa" BATCH_SIZE = 8 # --- Device Setup --- if torch.cuda.is_available(): device = torch.device("cuda") print("GPU detected. Using CUDA.") else: device = torch.device("cpu") print("No GPU detected. Using CPU.") # --- Model Loading --- HF_AUTH_TOKEN = os.getenv("HF_TOKEN") print(f"Loading model and tokenizer from: {MODEL_PATH}") try: tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN) model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN) model.to(device) model.eval() print(f"Model and tokenizer loaded successfully on device: {device}") except Exception as e: raise RuntimeError(f"FATAL Error loading model/tokenizer: {e}") # --- Helper Function for Chunking --- def chunk_text(text, max_chars): if not text or text.isspace(): return [] text = text.strip() if len(text) <= max_chars: return [text] chunks, current_index = [], 0 while current_index < len(text): potential_end_index = min(current_index + max_chars, len(text)) actual_end_index = potential_end_index if potential_end_index < len(text): punctuation = ".!?。!?,," best_split_pos = -1 for punc in punctuation: best_split_pos = max(best_split_pos, text.rfind(punc, current_index, potential_end_index)) if best_split_pos != -1: actual_end_index = best_split_pos + 1 chunk = text[current_index:actual_end_index] if chunk and not chunk.isspace(): chunks.append(chunk.strip()) current_index = actual_end_index if current_index >= len(text): break return [c for c in chunks if c] # --- Main Processing Function --- @spaces.GPU def translate_batch( text_input, max_chars_per_chunk, repetition_penalty, token_multiplier, # --- MODIFICATION: Added new parameter for the multiplier --- decoding_strategy, num_beams, length_penalty, use_early_stopping, temperature, top_p, progress=gr.Progress(track_tqdm=True) ): if not text_input or text_input.strip() == "": return "[Error] Please enter some text to process." lines = [line.strip() for line in text_input.splitlines() if line.strip()] if not lines: return "[Info] No valid text lines found in input." max_chars_per_chunk = int(max_chars_per_chunk) all_chunks = [] for line in lines: all_chunks.extend(chunk_text(line, max_chars_per_chunk)) if not all_chunks: return "[Info] No text chunks generated after processing input." generation_kwargs = {"repetition_penalty": repetition_penalty, "do_sample": False} if decoding_strategy == "Beam Search": generation_kwargs.update({ "num_beams": int(num_beams), "length_penalty": length_penalty, "early_stopping": use_early_stopping, }) elif decoding_strategy == "Sampling": generation_kwargs.update({ "do_sample": True, "temperature": temperature, "top_p": top_p, "num_beams": 1, }) print(f"Processing {len(all_chunks)} chunks with strategy: {decoding_strategy}. Args: {generation_kwargs}") print(f" Using token_multiplier: {token_multiplier}") # Log the multiplier all_ipa_outputs = [] num_batches = math.ceil(len(all_chunks) / BATCH_SIZE) for i in progress.tqdm(range(num_batches), desc="Processing Batches"): batch_start, batch_end = i * BATCH_SIZE, (i + 1) * BATCH_SIZE batch_chunks = all_chunks[batch_start:batch_end] try: inputs = tokenizer( batch_chunks, return_tensors="pt", padding=True, truncation=True, max_length=512 ).to(device) # --- MODIFICATION: Dynamic max_new_tokens using the slider value --- max_input_length = inputs["input_ids"].shape[1] generation_kwargs["max_new_tokens"] = min(int(max_input_length * token_multiplier) + 10, 512) with torch.no_grad(): outputs = model.generate(**generation_kwargs, **inputs) batch_ipa = tokenizer.batch_decode(outputs, skip_special_tokens=True) all_ipa_outputs.extend(batch_ipa) except Exception as e: print(f"Error during batch {i+1} processing: {e}") all_ipa_outputs.extend([f"[Error in batch {i+1}]"] * len(batch_chunks)) return "\n".join(all_ipa_outputs) # --- UI Helper Function for Dynamic Controls --- def update_decoding_ui(strategy): if strategy == "Beam Search": return gr.update(visible=True), gr.update(visible=False) elif strategy == "Sampling": return gr.update(visible=False), gr.update(visible=True) # --- Gradio UI using Blocks --- with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown( """ # 🇹🇭🇨🇳 Advanced Mandarin & Thai to IPA Converter Get the International Phonetic Alphabet (IPA) for Chinese (Mandarin) or Thai text. The model automatically detects the language. This interface provides advanced controls for tuning the output quality. """ ) with gr.Row(): with gr.Column(scale=1): input_textbox = gr.Textbox(lines=10, label="Input Text (Mandarin or Thai)", placeholder="Enter text here (e.g., 你好世界 or สวัสดีครับ).") submit_button = gr.Button("Submit", variant="primary") with gr.Column(scale=1): output_textbox = gr.Textbox(lines=10, label="IPA Output", interactive=False) with gr.Accordion("Generation & Chunking Controls", open=False): with gr.Row(): max_chars_slider = gr.Slider(minimum=1, maximum=512, step=1, value=36, label="Max Characters per Chunk", info="Splits long lines into smaller pieces for the model.") repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.1, label="Repetition Penalty", info=">1.0 discourages repeated words. Prevents stuttering.") # --- MODIFICATION: Added the new slider for token multiplier --- token_multiplier_slider = gr.Slider( minimum=1.0, maximum=4.0, step=0.1, value=2.0, label="Output Token Multiplier", info="Safety factor for output length. Increase if output is cut off." ) decoding_strategy_radio = gr.Radio(["Beam Search", "Sampling"], value="Beam Search", label="Decoding Strategy", info="Choose the method for generating text.") with gr.Group(visible=True) as beam_search_group: with gr.Row(): num_beams_slider = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of Beams", info="More beams can yield better quality but are slower.") length_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Penalty", info=">1.0 encourages longer output; <1.0 for shorter.") early_stopping_checkbox = gr.Checkbox(value=True, label="Early Stopping", info="Stop when a sentence is complete. Recommended.") with gr.Group(visible=False) as sampling_group: with gr.Row(): temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.8, label="Temperature", info="Controls randomness. Lower is more predictable.") top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.95, label="Top-p (Nucleus Sampling)", info="Considers a smaller, more probable set of next words.") gr.Markdown(f"**Model:** `{MODEL_PATH}` | **Batch Size:** `{BATCH_SIZE}` | **Device:** `{str(device).upper()}`") # --- Event Listeners --- decoding_strategy_radio.change(fn=update_decoding_ui, inputs=decoding_strategy_radio, outputs=[beam_search_group, sampling_group]) submit_button.click( fn=translate_batch, inputs=[ input_textbox, max_chars_slider, repetition_penalty_slider, token_multiplier_slider, # --- MODIFICATION: Added slider to inputs list --- decoding_strategy_radio, num_beams_slider, length_penalty_slider, early_stopping_checkbox, temperature_slider, top_p_slider, ], outputs=output_textbox ) # --- Launch the App --- if __name__ == "__main__": demo.launch() # --- END OF FILE ---