# --- START OF MODIFIED app.py ---
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    # Stand-in that mimics the spaces.GPU decorator so code decorated with it
    # still runs unchanged outside of Hugging Face Spaces.
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            def decorator(func):
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator

    spaces = DummySpaces()
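
# On a Space with ZeroGPU, the GPU-bound function could be wrapped with @spaces.GPU
# (e.g. above translate_batch below); note the decorator is not actually applied
# anywhere in this file, so the guard above only keeps local runs from failing.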

import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
import math

# --- Configuration ---
MODEL_PATH = "gregniuki/mandarin_thai_ipa"
BATCH_SIZE = 8

# --- Device Setup ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# --- Model Loading ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_PATH, token=HF_AUTH_TOKEN)
    model.to(device)
    model.eval()
    print(f"Model and tokenizer loaded successfully on device: {device}")
except Exception as e:
    raise RuntimeError(f"FATAL Error loading model/tokenizer: {e}") from e
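
# HF_TOKEN is only needed if the model repo is private or gated: set it as a Space
# secret, or locally via `export HF_TOKEN=...`. For a public repo it can stay unset;
# os.getenv then returns None, which from_pretrained accepts.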

# --- Helper Function for Chunking ---
def chunk_text(text, max_chars):
    """Split text into chunks of at most max_chars, preferring to break just after
    the last sentence/clause punctuation (ASCII or fullwidth) inside each window."""
    if not text or text.isspace():
        return []
    text = text.strip()
    if len(text) <= max_chars:
        return [text]
    chunks, current_index = [], 0
    while current_index < len(text):
        potential_end_index = min(current_index + max_chars, len(text))
        actual_end_index = potential_end_index
        if potential_end_index < len(text):
            punctuation = ".!?。!?,,"
            best_split_pos = -1
            for punc in punctuation:
                best_split_pos = max(best_split_pos, text.rfind(punc, current_index, potential_end_index))
            if best_split_pos != -1:
                actual_end_index = best_split_pos + 1
        chunk = text[current_index:actual_end_index]
        if chunk and not chunk.isspace():
            chunks.append(chunk.strip())
        current_index = actual_end_index
        if current_index >= len(text):
            break
    return [c for c in chunks if c]
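
# Illustrative example (hypothetical input, not from the original file):
# chunk_text("abc. defg. hij", 6) -> ["abc.", "defg.", "hij"]
# Each window of up to 6 characters is cut just after the last punctuation mark it
# contains, and the final remainder ("hij") is kept even without punctuation.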

# --- Main Processing Function ---
def translate_batch(
    text_input,
    max_chars_per_chunk,
    repetition_penalty,
    token_multiplier,  # --- MODIFICATION: Added new parameter for the multiplier ---
    decoding_strategy,
    num_beams,
    length_penalty,
    use_early_stopping,
    temperature,
    top_p,
    progress=gr.Progress(track_tqdm=True),
):
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to process."
    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."
    max_chars_per_chunk = int(max_chars_per_chunk)
    all_chunks = []
    for line in lines:
        all_chunks.extend(chunk_text(line, max_chars_per_chunk))
    if not all_chunks:
        return "[Info] No text chunks generated after processing input."
    generation_kwargs = {"repetition_penalty": repetition_penalty, "do_sample": False}
    if decoding_strategy == "Beam Search":
        generation_kwargs.update({
            "num_beams": int(num_beams), "length_penalty": length_penalty, "early_stopping": use_early_stopping,
        })
    elif decoding_strategy == "Sampling":
        generation_kwargs.update({
            "do_sample": True, "temperature": temperature, "top_p": top_p, "num_beams": 1,
        })
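    # For reference, the default UI values under "Beam Search" produce:
    # {"repetition_penalty": 1.1, "do_sample": False, "num_beams": 4,
    #  "length_penalty": 1.0, "early_stopping": True}
    # (max_new_tokens is added per batch in the loop below).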
print(f"Processing {len(all_chunks)} chunks with strategy: {decoding_strategy}. Args: {generation_kwargs}") | |
print(f" Using token_multiplier: {token_multiplier}") # Log the multiplier | |
all_ipa_outputs = [] | |
num_batches = math.ceil(len(all_chunks) / BATCH_SIZE) | |
for i in progress.tqdm(range(num_batches), desc="Processing Batches"): | |
batch_start, batch_end = i * BATCH_SIZE, (i + 1) * BATCH_SIZE | |
batch_chunks = all_chunks[batch_start:batch_end] | |
        try:
            inputs = tokenizer(
                batch_chunks, return_tensors="pt", padding=True, truncation=True, max_length=512
            ).to(device)
            # --- MODIFICATION: Dynamic max_new_tokens using the slider value ---
            max_input_length = inputs["input_ids"].shape[1]
            generation_kwargs["max_new_tokens"] = min(int(max_input_length * token_multiplier) + 10, 512)
            with torch.no_grad():
                outputs = model.generate(**generation_kwargs, **inputs)
            batch_ipa = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            all_ipa_outputs.extend(batch_ipa)
        except Exception as e:
            print(f"Error during batch {i+1} processing: {e}")
            all_ipa_outputs.extend([f"[Error in batch {i+1}]"] * len(batch_chunks))
    return "\n".join(all_ipa_outputs)

# --- UI Helper Function for Dynamic Controls ---
def update_decoding_ui(strategy):
    if strategy == "Beam Search":
        return gr.update(visible=True), gr.update(visible=False)
    elif strategy == "Sampling":
        return gr.update(visible=False), gr.update(visible=True)
    # Fallback so the callback never returns None for an unexpected value.
    return gr.update(), gr.update()

# --- Gradio UI using Blocks ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🇹🇭🇨🇳 Advanced Mandarin & Thai to IPA Converter
        Get the International Phonetic Alphabet (IPA) for Chinese (Mandarin) or Thai text. The model automatically detects the language.
        This interface provides advanced controls for tuning the output quality.
        """
    )
    with gr.Row():
        with gr.Column(scale=1):
            input_textbox = gr.Textbox(lines=10, label="Input Text (Mandarin or Thai)", placeholder="Enter text here (e.g., 你好世界 or สวัสดีครับ).")
            submit_button = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            output_textbox = gr.Textbox(lines=10, label="IPA Output", interactive=False)
    with gr.Accordion("Generation & Chunking Controls", open=False):
        with gr.Row():
            max_chars_slider = gr.Slider(minimum=1, maximum=512, step=1, value=36, label="Max Characters per Chunk", info="Splits long lines into smaller pieces for the model.")
            repetition_penalty_slider = gr.Slider(minimum=1.0, maximum=2.0, step=0.05, value=1.1, label="Repetition Penalty", info=">1.0 discourages repeated words. Prevents stuttering.")
        # --- MODIFICATION: Added the new slider for token multiplier ---
        token_multiplier_slider = gr.Slider(
            minimum=1.0, maximum=4.0, step=0.1, value=2.0,
            label="Output Token Multiplier",
            info="Safety factor for output length. Increase if output is cut off."
        )
        decoding_strategy_radio = gr.Radio(["Beam Search", "Sampling"], value="Beam Search", label="Decoding Strategy", info="Choose the method for generating text.")
        with gr.Group(visible=True) as beam_search_group:
            with gr.Row():
                num_beams_slider = gr.Slider(minimum=1, maximum=10, step=1, value=4, label="Number of Beams", info="More beams can yield better quality but are slower.")
                length_penalty_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Length Penalty", info=">1.0 encourages longer output; <1.0 for shorter.")
                early_stopping_checkbox = gr.Checkbox(value=True, label="Early Stopping", info="Stop when a sentence is complete. Recommended.")
        with gr.Group(visible=False) as sampling_group:
            with gr.Row():
                temperature_slider = gr.Slider(minimum=0.1, maximum=1.5, step=0.05, value=0.8, label="Temperature", info="Controls randomness. Lower is more predictable.")
                top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, step=0.05, value=0.95, label="Top-p (Nucleus Sampling)", info="Considers a smaller, more probable set of next words.")
gr.Markdown(f"**Model:** `{MODEL_PATH}` | **Batch Size:** `{BATCH_SIZE}` | **Device:** `{str(device).upper()}`") | |

    # --- Event Listeners ---
    decoding_strategy_radio.change(fn=update_decoding_ui, inputs=decoding_strategy_radio, outputs=[beam_search_group, sampling_group])
    submit_button.click(
        fn=translate_batch,
        inputs=[
            input_textbox,
            max_chars_slider,
            repetition_penalty_slider,
            token_multiplier_slider,  # --- MODIFICATION: Added slider to inputs list ---
            decoding_strategy_radio,
            num_beams_slider,
            length_penalty_slider,
            early_stopping_checkbox,
            temperature_slider,
            top_p_slider,
        ],
        outputs=output_textbox
    )

# --- Launch the App ---
if __name__ == "__main__":
    demo.launch()
# --- END OF FILE ---