try:
    import spaces
    print("'spaces' module imported successfully.")
except ImportError:
    print("Warning: 'spaces' module not found. Using dummy decorator for local execution.")

    # Define a dummy decorator that does nothing if 'spaces' isn't available
    class DummySpaces:
        def GPU(self, *args, **kwargs):
            def decorator(func):
                # This dummy decorator just returns the original function
                print(f"Note: Dummy @GPU decorator used for function '{func.__name__}'.")
                return func
            return decorator

    spaces = DummySpaces()  # Create an instance of the dummy class
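    # Usage note: apply this with a call, e.g. @spaces.GPU() or @spaces.GPU(duration=60),
    # so that GPU(...) returns the inner decorator; a bare @spaces.GPU would not
    # work with this dummy fallback.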
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # Or TFAutoModelForSeq2SeqLM
import torch  # Or import tensorflow as tf
import os
import math
# Requires a Gradio version supporting the spaces.GPU decorator if running on Spaces
# Might need: from gradio.external import spaces  <- if spaces not directly available
# import gradio.external as spaces  # Use this import path
from huggingface_hub import hf_hub_download
# --- Configuration ---
# IMPORTANT: REPLACE THIS with your model's Hugging Face Hub ID or local path
MODEL_PATH = "Gregniuki/pl-en-pl-v2"  # Use your actual model path
MAX_WORDS_PER_CHUNK = 40
BATCH_SIZE = 4  # Adjust based on GPU memory / desired throughput
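# Illustrative example: with BATCH_SIZE = 4, an input that yields 10 chunks is
# processed in math.ceil(10 / 4) = 3 batches (4 + 4 + 2 chunks).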
# --- Device Setup (Zero GPU Support) ---
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("GPU detected. Using CUDA.")
else:
    device = torch.device("cpu")
    print("No GPU detected. Using CPU.")

# --- Get Hugging Face Token from Secrets for Private Models ---
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if MODEL_PATH and "/" in MODEL_PATH and not os.path.exists(MODEL_PATH):  # Rough check if it's likely a Hub ID
    if HF_AUTH_TOKEN is None:
        print(f"Warning: HF_TOKEN secret not found. Trying to load {MODEL_PATH} without authentication.")
    else:
        print("HF_TOKEN found. Using token for model loading.")
else:
    print(f"Loading model from local path: {MODEL_PATH}")
    HF_AUTH_TOKEN = None  # Don't use token for local paths
# --- Load Model and Tokenizer (once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )

    # --- Choose the correct model class ---
    # PyTorch (most common)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,
        trust_remote_code=False
    )
    model.to(device)  # Move model to the determined device
    model.eval()      # Set model to evaluation mode
    print(f"Using PyTorch model on device: {device}")

    # # TensorFlow (uncomment if your model is TF)
    # from transformers import TFAutoModelForSeq2SeqLM
    # import tensorflow as tf
    # model = TFAutoModelForSeq2SeqLM.from_pretrained(
    #     MODEL_PATH,
    #     token=HF_AUTH_TOKEN,
    #     trust_remote_code=False
    # )
    # # TF device placement is often automatic or managed via strategies
    # print("Using TensorFlow model.")

    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"FATAL Error loading model/tokenizer: {e}")
    if "401 Client Error" in str(e):
        error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
        error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
    # Raise an error to prevent the app from launching if model loading fails
    raise RuntimeError(error_message)
# --- Helper Function for Chunking Sentences ---
def chunk_sentence(sentence, max_words):
    """Splits a sentence into chunks of at most max_words words."""
    if not sentence or sentence.isspace():
        return []
    words = sentence.split()  # Simple whitespace splitting
    chunks = []
    current_chunk = []
    for word in words:
        current_chunk.append(word)
        if len(current_chunk) >= max_words:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
    if current_chunk:  # Add any remaining words
        chunks.append(" ".join(current_chunk))
    return chunks
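
# Illustrative example:
#   chunk_sentence("one two three four five", max_words=2)
#   -> ["one two", "three four", "five"]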
# --- Define the BATCH translation function ---
# GPU decorator for Spaces (adjust duration if needed); with the dummy fallback
# above this is a no-op when running locally without the 'spaces' package.
@spaces.GPU()
def translate_batch(text_input):
""" | |
Translates multi-line input text using batching and sentence chunking. | |
Assumes auto-detection of language direction (no prefixes). | |
""" | |
if not text_input or text_input.strip() == "": | |
return "[Error] Please enter some text to translate." | |
print(f"Received input block for batch translation.") | |
# 1. Split input into potential sentences (lines) and clean | |
lines = [line.strip() for line in text_input.splitlines() if line.strip()] | |
if not lines: | |
return "[Info] No valid text lines found in input." | |
# 2. Chunk long sentences | |
all_chunks = [] | |
for line in lines: | |
sentence_chunks = chunk_sentence(line, MAX_WORDS_PER_CHUNK) | |
all_chunks.extend(sentence_chunks) | |
if not all_chunks: | |
return "[Info] No text chunks generated after processing input." | |
print(f"Processing {len(all_chunks)} chunks in batches...") | |
# 3. Process chunks in batches | |
all_translations = [] | |
num_batches = math.ceil(len(all_chunks) / BATCH_SIZE) | |
for i in range(num_batches): | |
batch_start = i * BATCH_SIZE | |
batch_end = batch_start + BATCH_SIZE | |
batch_chunks = all_chunks[batch_start:batch_end] | |
print(f" Processing batch {i+1}/{num_batches} ({len(batch_chunks)} chunks)") | |
# Tokenize the batch | |
try: | |
# PyTorch | |
inputs = tokenizer(batch_chunks, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device) | |
# # TensorFlow | |
# inputs = tokenizer(batch_chunks, return_tensors="tf", padding=True, truncation=True, max_length=512) | |
except Exception as e: | |
print(f"Error during batch tokenization: {e}") | |
# Return partial results or a general error | |
return "[Error] Tokenization failed for a batch." | |
# Generate translations for the batch | |
try: | |
# PyTorch | |
with torch.no_grad(): | |
outputs = model.generate( | |
**inputs, | |
max_length=1024, | |
num_beams=8, | |
early_stopping=False | |
) | |
# output_ids shape: [batch_size, sequence_length] | |
# # TensorFlow | |
# outputs = model.generate( | |
# inputs['input_ids'], | |
# attention_mask=inputs['attention_mask'], | |
# max_length=512, | |
# num_beams=4, | |
# early_stopping=True | |
# ) | |
# outputs is typically a tensor of shape [batch_size, sequence_length] | |
# Decode the batch results | |
batch_translations = tokenizer.batch_decode(outputs, skip_special_tokens=True) | |
all_translations.extend(batch_translations) | |
except Exception as e: | |
print(f"Error during batch generation/decoding: {e}") | |
# Return partial results or a general error | |
return "[Error] Translation generation failed for a batch." | |
# 4. Join translated chunks back together | |
# Simple join with newline, might not perfectly preserve original structure if chunking happened mid-sentence. | |
final_output = "\n".join(all_translations) | |
print("Batch translation finished.") | |
return final_output | |
# --- Create Gradio Interface for Batch Translation ---
input_textbox = gr.Textbox(
    lines=10,  # Allow more lines for batch input
    label="Input Text (Polish or English - one sentence per line recommended)",
    placeholder=f"Enter text here. Longer sentences will be split into chunks of max {MAX_WORDS_PER_CHUNK} words."
)
output_textbox = gr.Textbox(label="Translation Output", lines=10)
# Interface definition
interface = gr.Interface(
    fn=translate_batch,  # Use the batch function
    inputs=input_textbox,
    outputs=output_textbox,
    title="🇵🇱 <-> 🇬🇧 Batch ByT5 Translator (Auto-Detect, Chunking)",
    description=f"Translate multiple lines of text between Polish and English.\nModel: {MODEL_PATH}\nLong sentences are automatically split into chunks of max {MAX_WORDS_PER_CHUNK} words.",
    article="Enter text (ideally one sentence per line). Click Submit to translate all lines.",
    allow_flagging="never"
)
# --- Launch the App ---
if __name__ == "__main__":
    # Set share=True for a public link if running locally; not needed on Spaces
    interface.launch()