# translate2/app1.py
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM # Or TF...
import torch # Or import tensorflow as tf
import os # <--- ADDED: To access environment variables
# --- Configuration ---
# Use the EXACT Hub ID of your PRIVATE model
MODEL_PATH = "Gregniuki/pl-en-pl-v2"

# --- Get Hugging Face Token from Secrets --- # <--- ADDED SECTION
HF_AUTH_TOKEN = os.getenv("HF_TOKEN")
if HF_AUTH_TOKEN is None:
print("Warning: HF_TOKEN secret not found. Loading model without authentication.")
# Optionally, raise an error if the token is absolutely required:
# raise ValueError("HF_TOKEN secret is missing, cannot load private model.")
# --- END ADDED SECTION ---
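# Note (environment assumption): in a Hugging Face Space the token is typically
# provided as a repository secret named HF_TOKEN (Settings -> Variables and
# secrets), which the Space exposes as an environment variable. For local runs,
# export it before launching the app, e.g.:
#
#     export HF_TOKEN=hf_...   # in the shell, before `python app1.py`
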
# --- Load Model and Tokenizer (do this once on startup) ---
print(f"Loading model and tokenizer from: {MODEL_PATH}")
try:
    # --- MODIFIED: Pass the token ---
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,  # <--- ADDED
        trust_remote_code=False  # Set to True if the model requires it
    )
    # --- MODIFIED: Pass the token ---
    # PyTorch
    model = AutoModelForSeq2SeqLM.from_pretrained(
        MODEL_PATH,
        token=HF_AUTH_TOKEN,  # <--- ADDED
        trust_remote_code=False  # Set to True if the model requires it
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Using PyTorch model on device: {device}")

    # # TensorFlow (uncomment if using TF)
    # from transformers import TFAutoModelForSeq2SeqLM
    # import tensorflow as tf
    # model = TFAutoModelForSeq2SeqLM.from_pretrained(
    #     MODEL_PATH,
    #     token=HF_AUTH_TOKEN,  # <--- ADDED
    #     trust_remote_code=False
    # )
    # print("Using TensorFlow model.")
    # device = "cpu"

    model.eval()
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model/tokenizer: {e}")
    # Handle the most common failure mode explicitly (a 401 Unauthorized from the Hub)
    if "401 Client Error" in str(e):
        error_message = f"Authentication failed. Ensure the HF_TOKEN secret has read access to {MODEL_PATH}."
    else:
        error_message = f"Failed to load model from {MODEL_PATH}. Error: {e}"
    raise gr.Error(error_message)
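
# Optional startup smoke test (a minimal sketch; uncomment to confirm the
# loaded PyTorch model generates output before serving requests):
#
# with torch.no_grad():
#     _test = tokenizer("Hello", return_tensors="pt", truncation=True).to(device)
#     print(tokenizer.decode(model.generate(**_test, max_length=32)[0],
#                            skip_special_tokens=True))
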
# --- Define the translation function (KEEP AS IS, depending on prefix/no-prefix) ---
def translate_text(text_input): # Or def translate_text(text_input, direction):
    # ... (your existing translation logic remains the same) ...
    if not text_input or text_input.strip() == "":
        return "[Error] Please enter some text to translate."

    print(f"Received input: '{text_input}'")

    # Tokenize
    try:
        # PyTorch
        inputs = tokenizer(text_input, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        # # TensorFlow
        # inputs = tokenizer(text_input, return_tensors="tf", padding=True, truncation=True, max_length=512)
    except Exception as e:
        print(f"Error during tokenization: {e}")
        return f"[Error] Tokenization failed: {e}"

    # Generate
    try:
        # PyTorch
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_length=512, num_beams=4, early_stopping=True
            )
        output_ids = outputs[0]

        # # TensorFlow
        # outputs = model.generate(
        #     inputs['input_ids'], attention_mask=inputs['attention_mask'],
        #     max_length=512, num_beams=4, early_stopping=True
        # )
        # output_ids = outputs[0]

        translation = tokenizer.decode(output_ids, skip_special_tokens=True)
        print(f"Generated translation: '{translation}'")
        return translation
    except Exception as e:
        print(f"Error during generation/decoding: {e}")
        return f"[Error] Translation generation failed: {e}"
# --- Create Gradio Interface (KEEP AS IS, depending on prefix/no-prefix) ---
# Example for no-prefix model:
input_textbox = gr.Textbox(lines=4, label="Input Text (Polish or English)", placeholder="Enter text here...")
output_textbox = gr.Textbox(label="Translation")
interface = gr.Interface(
    fn=translate_text,
    inputs=input_textbox,
    outputs=output_textbox,
    title="🇵🇱 <-> 🇬🇧 Auto-Detecting ByT5 Translator",
    description=f"Translate text between Polish and English.\nModel: {MODEL_PATH}",
    article="Enter text and click Submit.",
    allow_flagging="never"
)
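
# If the prefixed handler sketched above were used instead, the interface would
# need a direction selector (hypothetical widget, matching the sketch's
# "pl-en"/"en-pl" labels):
#
# direction_radio = gr.Radio(["pl-en", "en-pl"], value="pl-en", label="Direction")
# interface = gr.Interface(fn=translate_text,
#                          inputs=[input_textbox, direction_radio],
#                          outputs=output_textbox,
#                          title="🇵🇱 <-> 🇬🇧 ByT5 Translator")
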
# --- Launch the App ---
if __name__ == "__main__":
    interface.launch()