# jsbeaudry's picture
# Update app.py
# ac738a4 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Checkpoint identifier handed to both from_pretrained() calls below.
model_name = "jsbeaudry/creole-translation-nllb-600M"

# Loaded once at import time; translate_text() reuses these module-level
# objects on every call instead of reloading them.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# translate_text relies on the module-level tokenizer and model loaded above.
def translate_text(text: str, *, max_length: int = 100) -> str:
    """Translate English text to Haitian Creole with the shared model.

    Args:
        text: English source text. ``None``, empty, or whitespace-only input
            yields an empty string without running the model.
        max_length: Value forwarded to ``generate`` as ``max_length``.
            Defaults to 100, the limit this function has always used.

    Returns:
        The decoded translation with special tokens stripped, or ``""`` when
        there is nothing to translate.
    """
    # Guard first: a blank submission should produce a blank result rather
    # than whatever the model generates from an input with no content.
    if not text or not text.strip():
        return ""

    # Source and target language codes.
    src_lang = "eng_Latn"
    tgt_lang = "hat_Latn"

    # Set the tokenizer to the source language, then tokenize the input.
    tokenizer.src_lang = src_lang
    inputs = tokenizer(text, return_tensors="pt")

    # Token ID of the target language code, used as the forced first token.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    # Use the GPU when available. Module.to() moves the module in place, so
    # the result is intentionally not assigned: rebinding ``model`` inside
    # this function would turn it into an unbound local name.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = inputs.to(device)

    generated_tokens = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=max_length,
    )

    # batch_decode returns a list of strings; take the entry for our input.
    translated = tokenizer.batch_decode(
        generated_tokens, skip_special_tokens=True
    )
    return translated[0]
# Input widget: a five-line text box with a hint shown while it is empty.
_source_box = gr.Textbox(lines=5, placeholder="Enter text to translate")

# Page heading and blurb passed to the interface.
_page_title = "English to Haitian Creole Translation"
_page_description = (
    "Translate English text to Haitian Creole using a fine-tuned NLLB model."
)

# Wire translate_text to the text box, with a plain-text output.
iface = gr.Interface(
    title=_page_title,
    description=_page_description,
    fn=translate_text,
    inputs=_source_box,
    outputs="text",
)

# Launch the interface as soon as this module is executed.
iface.launch()