jsbeaudry committed on
Commit
ac738a4
·
verified ·
1 Parent(s): 80d9532

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -25
app.py CHANGED
@@ -6,45 +6,46 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  model_name = "jsbeaudry/creole-translation-nllb-600M"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
 
9
 
10
- # Supported languages
11
- language_codes = {
12
- "English": "eng_Latn",
13
- "Haitian Creole": "hat_Latn"
14
- }
 
15
 
16
- # Translation function
17
- def translate_text(text, src_lang_name, tgt_lang_name):
18
- src_lang = language_codes[src_lang_name]
19
- tgt_lang = language_codes[tgt_lang_name]
20
 
21
- tokenizer.src_lang = src_lang
22
- inputs = tokenizer(text, return_tensors="pt")
23
- forced_bos_token_id = tokenizer.convert_tokens_to_ids(tgt_lang)
24
 
 
 
 
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
- model.to(device)
27
  inputs = inputs.to(device)
28
 
29
- generated_tokens = model.generate(
 
30
  **inputs,
31
  forced_bos_token_id=forced_bos_token_id,
32
  max_length=100
33
  )
34
- translated = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
 
 
35
  return translated[0]
36
 
37
- # Gradio interface
38
  iface = gr.Interface(
39
  fn=translate_text,
40
- inputs=[
41
- gr.Textbox(lines=5, placeholder="Enter text to translate", label="Input Text"),
42
- gr.Dropdown(choices=list(language_codes.keys()), value="English", label="Source Language"),
43
- gr.Dropdown(choices=list(language_codes.keys()), value="Haitian Creole", label="Target Language")
44
- ],
45
- outputs=gr.Textbox(label="Translated Text"),
46
- title="Multilingual Translation (English ↔ Haitian Creole)",
47
- description="Translate text between English and Haitian Creole using a fine-tuned NLLB model."
48
  )
49
 
50
- iface.launch()
 
6
  model_name = "jsbeaudry/creole-translation-nllb-600M"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
+ # Uses the tokenizer and model loaded above
10
 
11
+ def translate_text(text):
12
+ # Set the source and target language codes
13
+ src_lang = "eng_Latn"
14
+ tgt_lang = "hat_Latn"
15
+ tokenizer_ = tokenizer
16
+ model_ = model
17
 
18
+ # Set tokenizer to source language
19
+ tokenizer_.src_lang = src_lang
 
 
20
 
21
+ # Tokenize the input
22
+ inputs = tokenizer_(text, return_tensors="pt")
 
23
 
24
+ # Find the BOS token ID for the target language
25
+ forced_bos_token_id = tokenizer_.convert_tokens_to_ids(tgt_lang)
26
+
27
+ # Move model to GPU if available
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
+ model_ = model_.to(device)
30
  inputs = inputs.to(device)
31
 
32
+ # Generate translation
33
+ generated_tokens = model_.generate(
34
  **inputs,
35
  forced_bos_token_id=forced_bos_token_id,
36
  max_length=100
37
  )
38
+
39
+ # Decode and return the first translation
40
+ translated = tokenizer_.batch_decode(generated_tokens, skip_special_tokens=True)
41
  return translated[0]
42
 
 
43
  iface = gr.Interface(
44
  fn=translate_text,
45
+ inputs=gr.Textbox(lines=5, placeholder="Enter text to translate"),
46
+ outputs="text",
47
+ title="English to Haitian Creole Translation",
48
+ description="Translate English text to Haitian Creole using a fine-tuned NLLB model."
 
 
 
 
49
  )
50
 
51
+ iface.launch()