VDNT11 committed on
Commit a641f96 · verified · 1 Parent(s): 059ddc0

Update app.py

Files changed (1)
  1. app.py +123 -67
app.py CHANGED
@@ -8,80 +8,136 @@ from gtts import gTTS
  import soundfile as sf
  from transformers import VitsTokenizer, VitsModel, set_seed

- # Set Hugging Face token (via environment or user input)
- hf_token = st.text_input("Enter your Hugging Face API token", type="password")

- # Ensure token is provided
- if hf_token:

-     # Initialize BLIP for image captioning
-     blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-     blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda" if torch.cuda.is_available() else "cpu")

-     # Function to generate captions
-     def generate_caption(image_path):
-         image = Image.open(image_path).convert("RGB")
-         inputs = blip_processor(image, "image of", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
-         with torch.no_grad():
-             generated_ids = blip_model.generate(**inputs)
-         caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
-         return caption
-
-     # Function to load FB MMS TTS model with Hugging Face token
-     def load_fbmms_model(model_name, hf_token):
-         tokenizer = VitsTokenizer.from_pretrained(model_name, use_auth_token=hf_token)
-         model = VitsModel.from_pretrained(model_name, use_auth_token=hf_token)
-         return tokenizer, model
-
-     # Function to generate audio using Facebook MMS-TTS
-     def generate_audio_fbmms(text, model_name, hf_token, output_file):
-         tokenizer, model = load_fbmms_model(model_name, hf_token)
-         inputs = tokenizer(text=text, return_tensors="pt")
-         set_seed(555)
-         with torch.no_grad():
-             outputs = model(**inputs)
-         waveform = outputs.waveform[0].cpu().numpy()
-         sf.write(output_file, waveform, samplerate=model.config.sampling_rate)
-         return output_file
-
-     # Streamlit UI for TTS method
-     tts_method = st.selectbox(
-         "Choose Text-to-Speech Method",
-         options=["gTTS (Google)", "Facebook MMS TTS"],
-         index=0 # Default to gTTS
      )

-     # Select target languages with human-readable names
-     language_options = {
-         "hin_Deva": "Hindi (Devanagari)",
-         "mar_Deva": "Marathi (Devanagari)",
-         "guj_Gujr": "Gujarati (Gujarati)",
-         "urd_Arab": "Urdu (Arabic)"
-     }

      target_languages = st.multiselect(
-         "Select target languages for translation",
-         options=list(language_options.keys()),
-         format_func=lambda x: language_options[x]
      )

-     if uploaded_image is not None and target_languages:
-         caption = generate_caption(uploaded_image)
          for lang in target_languages:
-             st.write(f"Generating audio for {lang}...")
-             output_file = f"{lang}_audio.mp3"
-
-             if tts_method == "gTTS (Google)":
-                 lang_code = {
-                     "hin_Deva": "hi",
-                     "mar_Deva": "mr",
-                     "guj_Gujr": "gu",
-                     "urd_Arab": "ur"
-                 }.get(lang, "en")
-                 audio_file = generate_audio_gtts(translations[lang], lang_code, output_file)
-             else:
-                 model_name = f"facebook/mms-tts-{lang}"
-                 audio_file = generate_audio_fbmms(translations[lang], model_name, hf_token, output_file)
-
-             if audio_file:
-                 st.audio(audio_file)

  import soundfile as sf
  from transformers import VitsTokenizer, VitsModel, set_seed

+ # Clone and Install IndicTransToolkit repository
+ if not os.path.exists('IndicTransToolkit'):
+     os.system('git clone https://github.com/VarunGumma/IndicTransToolkit')
+     os.system('cd IndicTransToolkit && python3 -m pip install --editable ./')
+
+ # Ensure that IndicTransToolkit is installed and used properly
+ from IndicTransToolkit import IndicProcessor
+
+ # Initialize BLIP for image captioning
+ blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to("cuda" if torch.cuda.is_available() else "cpu")
+
+ # Function to generate captions
+ def generate_caption(image_path):
+     image = Image.open(image_path).convert("RGB")
+     inputs = blip_processor(image, "image of", return_tensors="pt").to("cuda" if torch.cuda.is_available() else "cpu")
+     with torch.no_grad():
+         generated_ids = blip_model.generate(**inputs)
+     caption = blip_processor.decode(generated_ids[0], skip_special_tokens=True)
+     return caption
+
+ # Function for translation using IndicTrans2
+ def translate_caption(caption, target_languages):
+     # Load model and tokenizer
+     model_name = "ai4bharat/indictrans2-en-indic-1B"
+     tokenizer_IT2 = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     model_IT2 = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
+     model_IT2 = torch.quantization.quantize_dynamic(
+         model_IT2, {torch.nn.Linear}, dtype=torch.qint8
      )

+     ip = IndicProcessor(inference=True)
+
+     # Source language (English)
+     src_lang = "eng_Latn"
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     model_IT2.to(DEVICE) # Move model to the device
+
+     # Integrating with workflow now
+     input_sentences = [caption]
+     translations = {}
+
+     for tgt_lang in target_languages:
+         # Preprocess input sentences
+         batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)
+
+         # Tokenize the sentences and generate input encodings
+         inputs = tokenizer_IT2(batch, truncation=True, padding="longest", return_tensors="pt").to(DEVICE)
+
+         # Generate translations using the model
+         with torch.no_grad():
+             generated_tokens = model_IT2.generate(
+                 **inputs,
+                 use_cache=True,
+                 min_length=0,
+                 max_length=256,
+                 num_beams=5,
+                 num_return_sequences=1,
+             )
+
+         # Decode the generated tokens into text
+         with tokenizer_IT2.as_target_tokenizer():
+             generated_tokens = tokenizer_IT2.batch_decode(generated_tokens.detach().cpu().tolist(), skip_special_tokens=True, clean_up_tokenization_spaces=True)
+
+         # Postprocess the translations
+         translated_texts = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
+         translations[tgt_lang] = translated_texts[0]
+
+     return translations
+
+ # Function to generate audio using gTTS
+ def generate_audio_gtts(text, lang_code, output_file):
+     tts = gTTS(text=text, lang=lang_code)
+     tts.save(output_file)
+     return output_file
+
+ # Function to generate audio using Facebook MMS-TTS
+ def generate_audio_fbmms(text, model_name, output_file):
+     tokenizer = VitsTokenizer.from_pretrained(model_name)
+     model = VitsModel.from_pretrained(model_name)
+     inputs = tokenizer(text=text, return_tensors="pt")
+     set_seed(555)
+     with torch.no_grad():
+         outputs = model(**inputs)
+     waveform = outputs.waveform[0].cpu().numpy()
+     sf.write(output_file, waveform, samplerate=model.config.sampling_rate)
+     return output_file

+ # Streamlit UI
+ st.title("Multilingual Assistive Model")
+
+ uploaded_image = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
+
+ if uploaded_image is not None:
+     # Display the uploaded image
+     image = Image.open(uploaded_image)
+     st.image(image, caption="Uploaded Image", use_column_width=True)
+
+     # Generate Caption
+     st.write("Generating Caption...")
+     caption = generate_caption(uploaded_image)
+     st.write(f"Caption: {caption}")
+
+     # Select target languages for translation
      target_languages = st.multiselect(
+         "Select target languages for translation",
+         ["hin_Deva", "mar_Deva", "guj_Gujr", "urd_Arab"], # Add more languages as needed
+         ["hin_Deva", "mar_Deva"]
      )

+     # Generate Translations
+     if target_languages:
+         st.write("Translating Caption...")
+         translations = translate_caption(caption, target_languages)
+         st.write("Translations:")
+         for lang, translation in translations.items():
+             st.write(f"{lang}: {translation}")
+
+         # Default to gTTS for TTS
          for lang in target_languages:
+             st.write(f"Using gTTS for {lang}...")
+             lang_code = {
+                 "hin_Deva": "hi", # Hindi
+                 "guj_Gujr": "gu", # Gujarati
+                 "urd_Arab": "ur" # Urdu
+             }.get(lang, "en")
+             output_file = f"{lang}_gTTS.mp3"
+             audio_file = generate_audio_gtts(translations[lang], lang_code, output_file)
+
+             st.write(f"Playing {lang} audio:")
+             st.audio(audio_file)
+ else:
+     st.write("Upload an image to start.")