import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM  # Or TF...
import torch  # Or import tensorflow as tf
import os  # <--- ADDED: To access environment variables

# --- Configuration ---
# Use the EXACT Hub ID of your PRIVATE model
MODEL_PATH = "Gregniuki/pl-en-pl-v2"

# --- Get Hugging Face Token from Secrets --- # <--- ADDED SECTION
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if HF_AUTH_TOKEN is None:
    print("Warning: HF_TOKEN secret not found. Loading model without authentication.")
    # Optionally, raise an error if the token is absolutely required:
    # raise ValueError("HF_TOKEN secret is missing, cannot load private model.")
# --- END ADDED SECTION ---
# --- Load Model and Tokenizer (do this once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    # --- MODIFIED: Pass the token ---
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,  # <--- ADDED
        trust_remote_code=False  # Set to True if the model requires it
    )

    # --- MODIFIED: Pass the token ---
    # PyTorch
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,  # <--- ADDED
        trust_remote_code=False  # Set to True if the model requires it
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Using PyTorch model on device: {device}")

    # # TensorFlow (uncomment if using TF)
    # from transformers import TFAutoModelForSeq2SeqLM
    # import tensorflow as tf
    # model = TFAutoModelForSeq2SeqLM.from_pretrained(
    #     MODEL_PATH,
    #     token=HF_AUTH_TOKEN,  # <--- ADDED
    #     trust_remote_code=False
    # )
    # print("Using TensorFlow model.")
    # device = "cpu"

    model.eval()  # PyTorch only; remove this line if using the TensorFlow model
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model/tokenizer: {e}")
    # A 401 Unauthorized usually means the token is missing or lacks read access
    if "401 Client Error" in str(e):
        error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
        error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
    raise gr.Error(error_message)
# --- Define the translation function (KEEP AS IS, depending on prefix/no-prefix) ---
def translate_text(text_input):  # Or def translate_text(text_input, direction):
    # ... (your existing translation logic remains the same) ...
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."
    print(f"Received input: '{text_input}'")

    # Tokenize
    try:
        # PyTorch
        inputs = tokenizer(text_input, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        # # TensorFlow
        # inputs = tokenizer(text_input, return_tensors="tf", padding=True, truncation=True, max_length=512)
    except Exception as e:
        print(f"Error during tokenization: {e}")
        return f"[Error] Tokenization failed: {e}"

    # Generate
    try:
        # PyTorch
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=512, num_beams=4, early_stopping=True
            )
        output_ids = outputs[0]
        # # TensorFlow
        # outputs = model.generate(
        #     inputs['input_ids'], attention_mask=inputs['attention_mask'],
        #     max_length=512, num_beams=4, early_stopping=True
        # )
        # output_ids = outputs[0]
        translation = tokenizer.decode(output_ids, skip_special_tokens=True)
        print(f"Generated translation: '{translation}'")
        return translation
    except Exception as e:
        print(f"Error during generation/decoding: {e}")
        return f"[Error] Translation generation failed: {e}"
# --- Create Gradio Interface (KEEP AS IS, depending on prefix/no-prefix) ---
# Example for no-prefix model:
input_textbox = gr.Textbox(lines=4, label="Input Text (Polish or English)", placeholder="Enter text here...")
output_textbox = gr.Textbox(label="Translation")

interface = gr.Interface(
    fn=translate_text,
    inputs=input_textbox,
    outputs=output_textbox,
    title="🇵🇱 <-> 🇬🇧 Auto-Detecting ByT5 Translator",
    description=f"Translate text between Polish and English.\nModel: {MODEL_PATH}",
    article="Enter text and click Submit.",
    allow_flagging="never"
)

# --- Launch the App ---
if __name__ == "__main__":
    interface.launch()
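
# --- Optional: pre-flight token check (a sketch; uncomment to use) ---
# If startup keeps failing with a 401, this verifies the secret outside of
# from_pretrained. huggingface_hub ships with transformers, and model_info
# raises if the token cannot read the private repo.
# from huggingface_hub import HfApi
# HfApi().model_info(MODEL_PATH, token=HF_AUTH_TOKEN)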