try: import spaces print("'spaces' module imported successfully.") except ImportError: print("Warning: 'spaces' module not found. Using dummy decorator for local execution.") # Define a dummy decorator that does nothing if 'spaces' isn't available class DummySpaces: def GPU(self, *args, **kwargs): def decorator(func): # This dummy decorator just returns the original function print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.") return func return decorator spaces = DummySpaces() # Create an instance of the dummy class import gradio as gr import re # Import the regular expression module from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Or TFAutoModelForSeq2SeqLM import torch # Or import tensorflow as tf import os import math # Requires Gradio version supporting spaces.GPU decorator if running on Spaces # Might need: from gradio.external import spaces <- if spaces not directly available #import gradio.external as spaces # Use this import path from huggingface_hub import hf_hub_download # --- Configuration --- # IMPORTANT: REPLACE THIS with your model's Hugging Face Hub ID or local path MODEL_PATH = "Gregniuki/pl-en-pl" # Use your actual model path MAX_WORDS_PER_CHUNK = 44 # Define the maximum words per chunk BATCH_SIZE = 8 # Adjust based on GPU memory / desired throughput # --- Device Setup (Zero GPU Support) --- if torch.cuda.is_available(): device = torch.device("cuda") print("GPU detected. Using CUDA.") else: device = torch.device("cpu") print("No GPU detected. Using CPU.") # --- Get Hugging Face Token from Secrets for Private Models --- HF_AUTH_TOKEN = os.getenv("HF_TOKEN") if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH): # Rough check if it's likely a Hub ID if HF_AUTH_TOKEN is None: print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.") else: print("HF_TOKEN found. Using token for model loading.") else: print(f"Loading model from local path: {MODEL_PATH}") HF_AUTH_TOKEN = None # Don't use token for local paths # --- Load Model and Tokenizer (once on startup) --- print(f"Loading model and tokenizer from: {MODEL_PATH}") try: tokenizer = AutoTokenizer.from_pretrained( MODEL_PATH, token=HF_AUTH_TOKEN, trust_remote_code=False ) # --- Choose the correct model class --- # PyTorch (most common) model = AutoModelForSeq2SeqLM.from_pretrained( MODEL_PATH, token=HF_AUTH_TOKEN, trust_remote_code=False ) model.to(device) # Move model to the determined device model.eval() # Set model to evaluation mode print(f"Using PyTorch model on device: {device}") # # TensorFlow (uncomment if your model is TF) # from transformers import TFAutoModelForSeq2SeqLM # import tensorflow as tf # model = TFAutoModelForSeq2SeqLM.from_pretrained( # MODEL_PATH, # token=HF_AUTH_TOKEN, # trust_remote_code=False # ) # # TF device placement is often automatic or managed via strategies # print("Using TensorFlow model.") print("Model and tokenizer loaded successfully.") except Exception as e: print(f"FATAL Error loading model/tokenizer: {e}") if "401 Client Error" in str(e): error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}." else: error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}" # Raise error to prevent app launch if model loading fails raise RuntimeError(error_message) # --- Helper Functions for Chunking --- def split_long_segment_by_comma_or_fallback(segment, max_words): """ Splits a long segment (already known > max_words) primarily by commas, falling back to simple word splitting if needed. """ if not segment or segment.isspace(): return [] # 1. Attempt to split by commas, keeping the comma and trailing whitespace # re.split splits *after* the pattern. (?<=,) looks behind for a comma. \s* matches trailing whitespace. comma_parts = re.split(r'(?<=,)\s*', segment) comma_parts = [p.strip() for p in comma_parts if p.strip()] # Trim and filter empty parts # If no commas found or splitting yielded strange results, fall back to word splitting if not comma_parts or (len(comma_parts) == 1 and len(comma_parts[0].split()) > max_words): # print(f"Debug: Falling back to word split for segment: '{segment[:100]}...'") # Optional debug # Fallback: Simple word-based chunking words = segment.split() segment_chunks = [] current_chunk_words = [] for word in words: current_chunk_words.append(word) # If adding the current word makes the chunk too long, finalize the previous words # and start a new chunk with the current word. if len(current_chunk_words) > max_words: # Add the chunk excluding the word that pushed it over segment_chunks.append(" ".join(current_chunk_words[:-1])) # Start a new chunk with the word that pushed it over current_chunk_words = [word] # Edge case: If the chunk is exactly max_words, finalize it unless it's the very first word. # This prevents a single chunk from staying at max_words forever if no further breaks are found. elif len(current_chunk_words) == max_words: segment_chunks.append(" ".join(current_chunk_words)) current_chunk_words = [] # Add any remaining words if current_chunk_words: segment_chunks.append(" ".join(current_chunk_words)) return segment_chunks # 2. Recombine comma-separated parts, respecting max_words segment_chunks = [] current_chunk_parts = [] # List to hold comma-separated strings for the current chunk current_chunk_word_count = 0 for i, part in enumerate(comma_parts): part_word_count = len(part.split()) # Check if adding this part makes the current chunk exceed max_words. # Condition `current_chunk_word_count > 0` ensures we don't break before adding the first part. # If the first part itself is > max_words, the fallback above handles it. if current_chunk_word_count > 0 and (current_chunk_word_count + part_word_count > max_words): # Finalize the current chunk (join the collected parts) segment_chunks.append(" ".join(current_chunk_parts).strip()) # Join with space, trim result # Start a new chunk with the current part current_chunk_parts = [part] current_chunk_word_count = part_word_count else: # Add the part to the current chunk current_chunk_parts.append(part) current_chunk_word_count += part_word_count # Add any remaining parts as the last chunk for this segment if current_chunk_parts: segment_chunks.append(" ".join(current_chunk_parts).strip()) return segment_chunks def chunk_sentence(sentence, max_words): """ Splits text into chunks based on max words, prioritizing sentence-ending punctuation (. ! ?), then commas (,) if the chunk is already >= max_words, falling back to word split. Processes the input line as potentially containing multiple sentences. """ if not sentence or sentence.isspace(): return [] all_final_chunks = [] # 1. Split the input line into potential "sentence segments" at . ! ? # Use regex split with lookbehind to split *after* the punctuation and space. # This yields segments that end in . ! ? (except possibly the very last segment). # Example: "Hello world. How are you? And you?" -> ["Hello world.", "How are you?", "And you?"] # Example: "Part one, part two. Part three." -> ["Part one, part two.", "Part three."] # Example: "No punctuation here" -> ["No punctuation here"] sentence_segments = re.split(r'(?<=[.!?])\s*', sentence) # Filter out empty strings that might result from splitting sentence_segments = [s.strip() for s in sentence_segments if s.strip()] # 2. Process each sentence segment for segment in sentence_segments: segment_word_count = len(segment.split()) if segment_word_count <= max_words: # Segment is short enough, add directly all_final_chunks.append(segment) else: # Segment is too long, apply comma splitting or fallback word splitting comma_based_chunks = split_long_segment_by_comma_or_fallback(segment, max_words) all_final_chunks.extend(comma_based_chunks) # Ensure no empty strings sneak through at the end return [chunk for chunk in all_final_chunks if chunk.strip()] # --- Define the BATCH translation function --- # Add GPU decorator for Spaces (adjust duration if needed) @spaces.GPU def translate_batch(text_input): """ Translates multi-line input text using batching and sentence chunking. Assumes auto-detection of language direction (no prefixes). """ if not text_input or text_input.strip() == "": return "[Error] Please enter some text to translate." print(f"Received input block for batch translation.") # 1. Split input into potential sentences (lines) and clean # Then chunk each line using the sophisticated chunk_sentence function lines = [line.strip() for line in text_input.splitlines() if line.strip()] if not lines: return "[Info] No valid text lines found in input." # 2. Chunk lines using the new logic all_chunks = [] for line in lines: # Process each line as a potential multi-sentence block for chunking line_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK) all_chunks.extend(line_chunks) if not all_chunks: return "[Info] No text chunks generated after processing input." print(f"Processing {len(all_chunks)} chunks in batches...") # 3. Process chunks in batches all_translations = [] num_batches = math.ceil(len(all_chunks) / BATCH_SIZE) for i in range(num_batches): batch_start = i * BATCH_SIZE batch_end = batch_start + BATCH_SIZE batch_chunks = all_chunks[batch_start:batch_end] print(f" Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)") # Tokenize the batch try: inputs = tokenizer( batch_chunks, return_tensors="pt", padding=True, truncation=True, max_length=1024 ).to(device) max_length = 1024 # your specified model max length max_input_length = inputs["input_ids"].shape[1] max_new_tokens = min(int(max_input_length * 1.2), max_length) print(f"Tokenized input (max_length={max_length})") for i, (text, input_ids) in enumerate(zip(batch_chunks, inputs["input_ids"])): print(f" Input {i + 1}: {len(input_ids)} tokens") print(f" Chunk {i + 1}: {repr(text)}...") # Print first 100 chars to keep output manageableu for idx, ids in enumerate(inputs["input_ids"]): print(f" Input {idx+1}: {len(ids)} tokens") except Exception as e: print(f"Error during batch tokenization: {e}") return "[Error] Tokenization failed for a batch." # Generate translations for the batch try: with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_new_tokens, num_beams=4, # no_repeat_ngram_size=3, early_stopping=False, return_dict_in_generate=True, output_scores=True ) print(f" Generation completed with max_new_tokens={max_new_tokens}") sequences = outputs.sequences for idx, seq in enumerate(sequences): print(f" Output {idx+1}: {len(seq)} tokens") batch_translations = tokenizer.batch_decode(sequences, skip_special_tokens=True) all_translations.extend(batch_translations) except Exception as e: print(f"Error during batch generation/decoding: {e}") return "[Error] Translation generation failed for a batch." # 4. Join translated chunks back together # Simple join with newline. The chunking logic aims to keep sentences/clauses together, # so joining by newline should preserve the overall structure reasonably well, # though it might not exactly match the original line breaks if chunking occurred within an original line. final_output = "\n".join(all_translations) print("Batch translation finished.") return final_output # --- Create Gradio Interface for Batch Translation --- input_textbox = gr.Textbox( lines=10, # Allow more lines for batch input label="Input Text (Polish or English - Enter multiple lines/sentences)", placeholder=f"Enter text here. Longer sentences/lines will be split into chunks (max {MAX_WORDS_PER_CHUNK} words) prioritizing . ! ? and , breaks." ) output_textbox = gr.Textbox(label="Translation Output", lines=10) # Interface definition interface = gr.Interface( fn=translate_batch, # Use the batch function inputs=input_textbox, outputs=output_textbox, title="🇵🇱 <-> 🇬🇧 Batch ByT5 Translator (Auto-Detect, Smart Chunking)", description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nText is automatically split into chunks of max {MAX_WORDS_PER_CHUNK} words, prioritizing breaks at . ! ? and ,", article="Enter text (you can paste multiple paragraphs or sentences). Click Submit to translate.\n\nChunking Logic:\n1. The entire input box content is split into potential 'sentence segments' using . ! ? as delimiters.\n2. Each segment is checked for word count.\n3. If a segment is <= {MAX_WORDS_PER_CHUNK} words, it's treated as a single chunk.\n4. If a segment is > {MAX_WORDS_PER_CHUNK} words, it's further split internally using commas (,) as preferred break points.\n5. If a long segment has no commas, or comma splitting isn't sufficient, it falls back to breaking purely by word count near {MAX_WORDS_PER_CHUNK} to avoid excessively long chunks.\n6. These final chunks are batched and translated.", allow_flagging="never" ) # --- Launch the App --- if __name__ == "__main__": # Set share=True for a public link if running locally, not needed on Spaces interface.launch()