# --- START OF FILE app.py ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    class DummySpaces:
        """No-op stand-in so @spaces.GPU still works outside Hugging Face Spaces."""
        def GPU(self, *args, **kwargs):
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator

    spaces = DummySpaces()
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import math
# --- Configuration ---
MODEL_PATH = "gregniuki/mandarin_thai_ipa"
BATCH_SIZE = 8
# --- Device Setup ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")
# --- Model Loading ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN)
    model.to(device)
    model.eval()
    print(f"Model and tokenizer loaded successfully on device: {device}")
except Exception as e:
    raise RuntimeError(f"FATAL Error loading model/tokenizer: {e}")
# --- Helper Function for Chunking ---
def chunk_text(text, max_chars):
    """Split text into chunks of at most max_chars, preferring punctuation boundaries."""
    if not text or text.isspace():
        return []
    text = text.strip()
    if len(text) <= max_chars:
        return [text]
    chunks, current_index = [], 0
    while current_index < len(text):
        potential_end_index = min(current_index + max_chars, len(text))
        actual_end_index = potential_end_index
        if potential_end_index < len(text):
            # Prefer to break at the last sentence-final punctuation mark
            # (ASCII or fullwidth) inside the current window.
            punctuation = ".!?。!?,,"
            best_split_pos = -1
            for punc in punctuation:
                best_split_pos = max(best_split_pos, text.rfind(punc, current_index, potential_end_index))
            if best_split_pos != -1:
                actual_end_index = best_split_pos + 1
        chunk = text[current_index:actual_end_index]
        if chunk and not chunk.isspace():
            chunks.append(chunk.strip())
        current_index = actual_end_index
    return [c for c in chunks if c]
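# Illustrative example with hypothetical input: chunk_text("你好。世界!", 3)
# returns ["你好。", "世界!"], since the fullwidth "。" inside the first
# 3-character window is preferred over a hard cut at the character limit.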
# --- Main Processing Function ---
@spaces.GPU
def translate_batch(
    text_input,
    max_chars_per_chunk,
    repetition_penalty,
    token_multiplier,  # Safety factor applied to the input length when sizing max_new_tokens.
    decoding_strategy,
    num_beams,
    length_penalty,
    use_early_stopping,
    temperature,
    top_p,
    progress=gr.Progress(track_tqdm=True),
):
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to process."
    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."
    max_chars_per_chunk = int(max_chars_per_chunk)
    all_chunks = []
    for line in lines:
        all_chunks.extend(chunk_text(line, max_chars_per_chunk))
    if not all_chunks:
        return "[Info] No text chunks generated after processing input."
    generation_kwargs = {"repetition_penalty": repetition_penalty, "do_sample": False}
    if decoding_strategy == "Beam Search":
        generation_kwargs.update({
            "num_beams": int(num_beams),
            "length_penalty": length_penalty,
            "early_stopping": use_early_stopping,
        })
    elif decoding_strategy == "Sampling":
        generation_kwargs.update({
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
            "num_beams": 1,
        })
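    # Illustrative: with the UI defaults below (Beam Search, num_beams=4,
    # length_penalty=1.0, early_stopping=True, repetition_penalty=1.1),
    # generation_kwargs resolves to {"repetition_penalty": 1.1, "do_sample": False,
    # "num_beams": 4, "length_penalty": 1.0, "early_stopping": True}.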
print(f"Processing {len(all_chunks)} chunks with strategy: {decoding_strategy}. Args: {generation_kwargs}")
print(f" Using token_multiplier: {token_multiplier}") # Log the multiplier
all_ipa_outputs = []
num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)
for i in progress.tqdm(range(num_batches), desc="Processing Batches"):
batch_start, batch_end = i * BATCH_SIZE, (i + 1) * BATCH_SIZE
batch_chunks = all_chunks[batch_start:batch_end]
try:
inputs = tokenizer(
batch_chunks, return_tensors="pt", padding=True, truncation=True, max_length=512
).to(device)
# --- MODIFICATION: Dynamic max_new_tokens using the slider value ---
max_input_length = inputs["input_ids"].shape[1]
generation_kwargs["max_new_tokens"] = min(int(max_input_length * token_multiplier) + 10, 512)
            with torch.no_grad():
                outputs = model.generate(**generation_kwargs, **inputs)
            batch_ipa = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            all_ipa_outputs.extend(batch_ipa)
        except Exception as e:
            print(f"Error during batch {i+1} processing: {e}")
            all_ipa_outputs.extend([f"[Error in batch {i+1}]"] * len(batch_chunks))
    return "\n".join(all_ipa_outputs)
# --- UI Helper Function for Dynamic Controls ---
def update_decoding_ui(strategy):
    """Toggle visibility of the Beam Search and Sampling control groups."""
    if strategy == "Beam Search":
        return gr.update(visible=True), gr.update(visible=False)
    # Fall back to showing the Sampling controls for any other selection.
    return gr.update(visible=False), gr.update(visible=True)
# --- Gradio UI using Blocks ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🇹🇭🇨🇳 Advanced Mandarin & Thai to IPA Converter
        Get the International Phonetic Alphabet (IPA) for Chinese (Mandarin) or Thai text. The model automatically detects the language.
        This interface provides advanced controls for tuning the output quality.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_textbox = gr.Textbox(lines=10, label="Input Text (Mandarin or Thai)", placeholder="Enter text here (e.g., 你好世界 or สวัสดีครับ).")
            submit_button = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            output_textbox = gr.Textbox(lines=10, label="IPA Output", interactive=False)
    with gr.Accordion("Generation & Chunking Controls", open=False):
        with gr.Row():
            max_chars_slider = gr.Slider(minimum=1, maximum=512, step=1, value=36, label="Max Characters per Chunk", info="Splits long lines into smaller pieces for the model.")
            repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.1, label="Repetition Penalty", info=">1.0 discourages repeated words and prevents stuttering.")
            token_multiplier_slider = gr.Slider(
                minimum=1.0, maximum=4.0, step=0.1, value=2.0,
                label="Output Token Multiplier",
                info="Safety factor for output length. Increase if output is cut off."
            )
        decoding_strategy_radio = gr.Radio(["Beam Search", "Sampling"], value="Beam Search", label="Decoding Strategy", info="Choose the method for generating text.")
        with gr.Group(visible=True) as beam_search_group:
            with gr.Row():
                num_beams_slider = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of Beams", info="More beams can yield better quality but are slower.")
                length_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Penalty", info=">1.0 encourages longer output; <1.0 encourages shorter.")
                early_stopping_checkbox = gr.Checkbox(value=True, label="Early Stopping", info="Stop when a sentence is complete. Recommended.")
        with gr.Group(visible=False) as sampling_group:
            with gr.Row():
                temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.8, label="Temperature", info="Controls randomness. Lower is more predictable.")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.95, label="Top-p (Nucleus Sampling)", info="Considers a smaller, more probable set of next words.")
    gr.Markdown(f"**Model:** `{MODEL_PATH}` | **Batch Size:** `{BATCH_SIZE}` | **Device:** `{str(device).upper()}`")
    # --- Event Listeners ---
    decoding_strategy_radio.change(fn=update_decoding_ui, inputs=decoding_strategy_radio, outputs=[beam_search_group, sampling_group])
    submit_button.click(
        fn=translate_batch,
        # Note: the order of this list must match translate_batch's parameter order.
        inputs=[
            input_textbox,
            max_chars_slider,
            repetition_penalty_slider,
            token_multiplier_slider,
            decoding_strategy_radio,
            num_beams_slider,
            length_penalty_slider,
            early_stopping_checkbox,
            temperature_slider,
            top_p_slider,
        ],
        outputs=output_textbox,
    )
# --- Launch the App ---
if __name__ == "__main__":
    demo.launch()
# --- END OF FILE ---