|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
# Hugging Face Hub id of the fine-tuned NLLB (600M) English -> Haitian Creole model.
model_name = "jsbeaudry/creole-translation-nllb-600M"

# Load the tokenizer and the seq2seq model once at import time so every
# request served by the UI reuses the same objects.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)
|
|
|
|
|
def translate_text(text, src_lang="eng_Latn", tgt_lang="hat_Latn", max_length=100):
    """Translate *text* with the module-level NLLB tokenizer and model.

    Args:
        text: Source text (the Gradio textbox value).
        src_lang: NLLB language code of the input; defaults to English
            (``"eng_Latn"``).
        tgt_lang: NLLB language code to translate into; defaults to Haitian
            Creole (``"hat_Latn"``).
        max_length: Upper bound, in tokens, on the generated sequence passed
            to ``model.generate``.

    Returns:
        The decoded translation of the first (only) sequence, or ``""`` when
        *text* is ``None``, empty, or whitespace-only.
    """
    # Nothing to translate: avoid running the model on an empty prompt, which
    # would otherwise yield arbitrary generated text.
    if not text or not text.strip():
        return ""

    # The tokenizer reads ``src_lang`` to choose the source-language token.
    # NOTE(review): this mutates the shared module-level tokenizer; concurrent
    # calls with different src_lang values could race.
    tokenizer.src_lang = src_lang

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # ``Module.to`` moves the model in place; only do it when the model is not
    # already on the target device instead of on every request.
    if model.device.type != device:
        model.to(device)

    inputs = tokenizer(text, return_tensors="pt").to(device)

    # Forcing the first generated token to the target-language code is what
    # selects the output language for NLLB.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)

    generated_tokens = model.generate(
        **inputs,
        forced_bos_token_id=forced_bos_token_id,
        max_length=max_length,
    )

    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
|
|
# Single multi-line textbox feeding translate_text; the result is shown as text.
_text_input = gr.Textbox(lines=5, placeholder="Enter text to translate")

iface = gr.Interface(
    title="English to Haitian Creole Translation",
    description="Translate English text to Haitian Creole using a fine-tuned NLLB model.",
    fn=translate_text,
    inputs=_text_input,
    outputs="text",
)

# Start the Gradio server (runs at import time, as in the original script).
iface.launch()