Spaces:
Sleeping
Sleeping
# --- START OF FILE app (5).py --- | |
try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    class DummySpaces:
        """Local stand-in for the Hugging Face ``spaces`` package.

        Only ``GPU`` is provided. It accepts both decorator forms used with the
        real package -- bare ``@spaces.GPU`` and called ``@spaces.GPU(...)``,
        e.g. ``@spaces.GPU(duration=120)`` -- and in both cases returns the
        decorated function unchanged.
        """

        def GPU(self, *args, **kwargs):
            def decorator(func):
                # No-op: hand back the original function untouched.
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func

            # Bare form (@spaces.GPU): the function itself is the sole argument.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return decorator(args[0])
            # Called form (@spaces.GPU() / @spaces.GPU(duration=...)).
            return decorator

    spaces = DummySpaces()  # Module-level name mirrors what `import spaces` provides
import inspect
import math
import os
import re  # Import the regular expression module

import gradio as gr
import torch  # Or import tensorflow as tf
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # Or TFAutoModelForSeq2SeqLM

# Requires Gradio version supporting spaces.GPU decorator if running on Spaces
# Might need: from gradio.external import spaces <- if spaces not directly available
#import gradio.external as spaces # Use this import path
# --- Configuration ---
# IMPORTANT: point this at your own model -- a Hugging Face Hub repo ID
# ("owner/name") or a local directory.
MODEL_PATH = "Gregniuki/pl-en-pl"
# Input lines with more words than this are split before translation.
MAX_WORDS_PER_CHUNK = 55
# Number of chunks passed to one model.generate() call; tune for GPU memory / throughput.
BATCH_SIZE = 8

# --- Device Setup (Zero GPU Support) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("GPU detected. Using CUDA." if device.type == "cuda" else "No GPU detected. Using CPU.")

# --- Get Hugging Face Token from Secrets for Private Models ---
# The token is kept only when MODEL_PATH looks like a Hub ID (a rough check:
# it contains "/" and is not an existing filesystem path); otherwise the path
# is treated as local and the token is discarded.
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if not (MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH)):
    print(f"Loading model from local path: {MODEL_PATH}")
    HF_AUTH_TOKEN = None  # Local paths need no authentication
elif HF_AUTH_TOKEN is None:
    print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.")
else:
    print("HF_TOKEN found. Using token for model loading.")
# --- Load Model and Tokenizer (once on startup) ---
# Loads the tokenizer and seq2seq model from MODEL_PATH (Hub ID or local path),
# moves the model to `device`, and aborts start-up with RuntimeError on failure.
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False,
    )
    # PyTorch seq2seq model. For a TensorFlow checkpoint, load it with
    # transformers.TFAutoModelForSeq2SeqLM instead and drop the .to(device)
    # call below (TF usually manages device placement itself).
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False,
    )
    model.to(device)  # Move model to the device chosen during device setup
    model.eval()  # Evaluation mode (disables dropout) for inference
    print(f"Using PyTorch model on device: {device}")
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"FATAL Error loading model/tokenizer: {e}")
    if "401 Client Error" in str(e):
        error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
        error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
    # Prevent the app from launching without a model. Chain the original
    # exception so its full traceback (e.g. HTTP details) stays visible.
    raise RuntimeError(error_message) from e
# --- Helper Functions for Chunking ---
def chunk_sentence(sentence, max_words, punctuation=('.', '!', '?', ',')):
    """
    Split a line of text into chunks of at most ``max_words`` words.

    A line with ``max_words`` words or fewer is returned as a single chunk,
    unchanged except for stripping leading/trailing whitespace. A longer line
    is cut repeatedly. For each window of ``max_words`` words, the split goes
    *after* the last word in the window that ends with one of ``punctuation``.
    If there is no such word, the window is cut exactly at ``max_words``.

    The first word of each window is never used as a punctuation split point,
    so a punctuation split always keeps at least two words. Words of split
    chunks are re-joined with single spaces.

    Args:
        sentence: Text to split (typically one input line).
        max_words: Maximum number of whitespace-separated words per chunk;
            must be at least 1.
        punctuation: Word endings that mark a preferred split point. A string
            is treated as a collection of single characters.

    Returns:
        A list of non-empty chunk strings; an empty list for empty or
        whitespace-only input.

    Raises:
        ValueError: If ``max_words`` is less than 1 (for non-blank input).
    """
    if not sentence or sentence.isspace():
        return []
    if max_words < 1:
        raise ValueError(f"max_words must be at least 1, got {max_words}")

    sentence = sentence.strip()  # Ensure no leading/trailing whitespace
    words = sentence.split()
    word_count = len(words)

    # Short enough: keep the line as one chunk.
    if word_count <= max_words:
        return [sentence]

    endings = tuple(punctuation)  # str.endswith needs a str or a tuple
    chunks = []
    start = 0
    while start < word_count:
        end = min(start + max_words, word_count)
        # Only look for a nicer split point if text remains after this window.
        if end < word_count:
            # Search backwards for the last punctuated word in (start, end).
            for i in range(end - 1, start, -1):
                if words[i].endswith(endings):
                    end = i + 1  # Split *after* the punctuated word
                    break
        # end > start always holds because max_words >= 1, so the loop advances.
        chunks.append(" ".join(words[start:end]))
        start = end
    return chunks
# --- Define the BATCH translation function ---
def _translate_chunk_batch(batch_chunks, batch_number):
    """
    Translate one batch of text chunks with the global tokenizer/model.

    Args:
        batch_chunks: List of source-text strings (already chunked).
        batch_number: 1-based batch index, used only in log and error messages.

    Returns:
        A list with exactly one entry per input chunk. Each entry is either the
        decoded translation or an ``[Error ...]`` placeholder if tokenization
        or generation failed. Keeping the length fixed keeps the output aligned
        with the input chunks.
    """
    try:
        inputs = tokenizer(
            batch_chunks,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=1024,  # Model's max input length
        ).to(device)
        # Padded length of the longest chunk in this batch, in tokenizer units.
        max_input_length = inputs["input_ids"].shape[1]
        # Output budget heuristic: ~20% expansion over the longest input plus a
        # small constant. The 1024 cap applies to *new* tokens only; it does not
        # include the input length.
        max_new_tokens = min(int(max_input_length * 1.2) + 10, 1024)
        print(f"Tokenized input (batch max length={max_input_length}), setting max_new_tokens={max_new_tokens}")
    except Exception as e:
        print(f"Error during batch tokenization: {e}")
        # One placeholder per chunk, matching the generation-failure path below.
        return [f"[Error tokenizing batch {batch_number}]"] * len(batch_chunks)

    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                num_beams=4,
                early_stopping=True,  # Stop beam search once all beams are finished
            )
        print(f" Generation completed for batch {batch_number}")
        # Default generate() output is the sequences tensor.
        return tokenizer.batch_decode(outputs, skip_special_tokens=True)
    except Exception as e:
        print(f"Error during batch generation/decoding: {e}")
        return [f"[Error translating batch {batch_number}]"] * len(batch_chunks)


# GPU decorator for Spaces. The called form works with both the real `spaces`
# package and the local dummy fallback; pass duration=... if batches need longer.
@spaces.GPU()
def translate_batch(text_input):
    """
    Translate multi-line input text using sentence chunking and batching.

    The input is split into non-empty lines. Each line is split into chunks of
    at most MAX_WORDS_PER_CHUNK words by ``chunk_sentence``, and the chunks are
    translated BATCH_SIZE at a time. No language prefix is added, so the model
    is assumed to auto-detect the translation direction. Each translated chunk
    becomes one output line, so a long input line that was split produces
    several output lines.

    Args:
        text_input: Raw text from the UI textbox.

    Returns:
        The translations joined with newlines, or a bracketed ``[Error]`` /
        ``[Info]`` message when there is nothing to translate. A batch that
        fails to tokenize or generate contributes one ``[Error ...]`` line per
        chunk.
    """
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."
    print("Received input block for batch translation.")

    # 1. Split input into lines and drop blank ones.
    lines = [line.strip() for line in text_input.splitlines() if line.strip()]
    if not lines:
        return "[Info] No valid text lines found in input."

    # 2. Chunk each line independently.
    all_chunks = [
        chunk
        for line in lines
        for chunk in chunk_sentence(line, MAX_WORDS_PER_CHUNK)
    ]
    if not all_chunks:
        return "[Info] No text chunks generated after processing input."
    print(f"Processing {len(all_chunks)} chunks in batches...")

    # 3. Translate the chunks BATCH_SIZE at a time.
    all_translations = []
    num_batches = math.ceil(len(all_chunks) / BATCH_SIZE)
    for batch_index in range(num_batches):
        batch_start = batch_index * BATCH_SIZE
        batch_chunks = all_chunks[batch_start:batch_start + BATCH_SIZE]
        print(f" Processing batch {batch_index + 1}/{num_batches} ({len(batch_chunks)} chunks)")
        all_translations.extend(_translate_chunk_batch(batch_chunks, batch_index + 1))

    # 4. One output line per translated chunk.
    final_output = "\n".join(all_translations)
    print("Batch translation finished.")
    return final_output
# --- Create Gradio Interface for Batch Translation ---
input_textbox = gr.Textbox(
    lines=10,  # Allow more lines for batch input
    label="Input Text (Polish or English - Enter multiple lines/sentences)",
    placeholder=f"Enter text here. Lines longer than {MAX_WORDS_PER_CHUNK} words will be split, prioritizing breaks after . ! ? , near the limit."
)
output_textbox = gr.Textbox(label="Translation Output", lines=10)

# Newer Gradio releases renamed Interface(allow_flagging=...) to
# flagging_mode=..., and Gradio 5 rejects the old keyword. Pick whichever
# keyword the installed gr.Interface accepts so the app starts on either.
_flagging_kwarg = (
    "flagging_mode"
    if "flagging_mode" in inspect.signature(gr.Interface.__init__).parameters
    else "allow_flagging"
)

# Interface definition
interface = gr.Interface(
    fn=translate_batch,  # Use the batch function
    inputs=input_textbox,
    outputs=output_textbox,
    title="🇵🇱 <-> 🇬🇧 Batch ByT5 Translator (Auto-Detect, Smart Chunking)",
    description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nText is processed line by line. Lines longer than {MAX_WORDS_PER_CHUNK} words are split into chunks.",
    # Article explaining the chunking logic to users
    article=f"Enter text (you can paste multiple paragraphs or sentences). Click Submit to translate.\n\nChunking Logic:\n1. Each line you enter is processed independently.\n2. If a line contains {MAX_WORDS_PER_CHUNK} words or fewer, it is translated as a single unit.\n3. If a line contains more than {MAX_WORDS_PER_CHUNK} words, it is split into smaller chunks.\n4. When splitting, the algorithm looks for the last punctuation mark (. ! ? ,) within the first {MAX_WORDS_PER_CHUNK} words to use as a natural break point.\n5. If no suitable punctuation is found in that range, the line is split exactly at the {MAX_WORDS_PER_CHUNK}-word limit.\n6. This process repeats for the remainder of the line until all parts are below the word limit.\n7. These final chunks are then translated in batches.",
    **{_flagging_kwarg: "never"},
)
# --- Launch the App ---
def _launch_app():
    """Start the Gradio server for the translator interface.

    No extra arguments are needed on Hugging Face Spaces; when running
    locally, ``interface.launch(share=True)`` gives a temporary public link.
    """
    interface.launch()


if __name__ == "__main__":
    _launch_app()
# --- END OF FILE app (5).py ---