Upload 8 files
- app.py +144 -0
- create_srt.py +16 -0
- custom_theme.py +69 -0
- install.sh +33 -0
- process_yt_video.py +35 -0
- transcription.py +52 -0
- translation.py +219 -0
- transliteration.py +1 -0
app.py
ADDED
@@ -0,0 +1,144 @@
import gradio as gr
from transcription import transcribe_audio
from translation import translate_text

from process_yt_video import download_audio, get_embed_url
from create_srt import create_srt
from custom_theme import CustomTheme

stheme = CustomTheme()

# Function to handle transcription
def process_transcription(audio_file, youtube_url):
    if youtube_url:
        audio_file = download_audio(youtube_url)
        print(f"Downloaded audio file from YouTube: {audio_file}")

    if not audio_file:
        return "No audio provided!", None  # two values, matching the click handler's two outputs

    print(f"Processing audio file: {audio_file}")
    detected_lang, transcription = transcribe_audio(audio_file)

    if not transcription:
        return "Error in transcription", None

    return detected_lang, transcription

# Function to handle translation
def process_translation(transcription, target_lang, detected_lang):
    if not transcription:
        return "Please transcribe first!"

    translated_text = translate_text(transcription, target_lang, detected_lang)
    return translated_text

# Function to handle subtitle generation
def process_subtitle(transcription, translation):
    if not transcription or not translation:
        return "Please transcribe and translate first!", None

    subtitle_file = create_srt(transcription, translation)
    return "Subtitle generated successfully!", subtitle_file

# Function to handle transliteration
'''def process_transliteration(translated_text):
    if not translated_text:
        return "Please translate first!"
    return "hello"'''

# Function to update the embedded YouTube video player
def update_video(youtube_url):
    embed_url = get_embed_url(youtube_url)
    return f"<iframe width='560' height='315' src='{embed_url}' frameborder='0' allowfullscreen></iframe>" if embed_url else ""

with gr.Blocks(theme=stheme) as demo:
    gr.Markdown("# Voice-to-Text Translation System", elem_id="title")

    with gr.Row():

        with gr.Column():
            gr.Markdown("## Upload Audio or Enter YT URL")
            audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Record or Upload Audio 🎤", min_width=50)
            youtube_url = gr.Textbox(label="Enter YouTube Link", min_width=50)
            video_player = gr.HTML("")
            youtube_url.change(update_video, inputs=[youtube_url], outputs=[video_player])

        with gr.Column():
            with gr.Row():
                gr.Markdown("")
                transcribe_button = gr.Button("Generate Transcription", interactive=True, size="sm", min_width=800)
            detected_language = gr.Textbox(label="Detected Language", interactive=False, min_width=400)
            transcription_output = gr.Textbox(label="Transcription", interactive=False, min_width=400)

            with gr.Row():
                language_selector = gr.Dropdown([
                    'Assamese', 'Bengali', 'Bodo', 'Dogri', 'English', 'Gujarati', 'Hindi', 'Kannada', 'Kashmiri(Perso-Arabic script)',
                    'Kashmiri(Devanagari script)', 'Konkani', 'Maithili', 'Malayalam', 'Manipuri(Bengali script)',
                    'Manipuri(Meitei script)', 'Marathi', 'Nepali', 'Odia', 'Punjabi', 'Sanskrit', 'Santali(Ol Chiki script)',
                    'Sindhi(Perso-Arabic script)', 'Sindhi(Devanagari script)', 'Tamil', 'Telugu', 'Urdu'
                ], label="Select Target Language", min_width=400)
            with gr.Row():
                translate_button = gr.Button("Generate Translation", interactive=True, size="sm", min_width=350)
            with gr.Row():
                translation_output = gr.Textbox(label="Translation", interactive=False, min_width=400)

            with gr.Row():
                subtitle_button = gr.Button("Generate Subtitles", interactive=True, size="sm", min_width=350)
            with gr.Row():
                subtitle_status = gr.Textbox(label="Subtitle Status", interactive=False, min_width=400)
                subtitle_download = gr.File(label="Download Subtitles", visible=True, min_width=400)

    '''with gr.Column():
        transliterate_button = gr.Button("Generate Transliteration", interactive=True)
        transliteration_output = gr.Textbox(label="Transliteration", interactive=False)'''

    transcribe_button.click(
        process_transcription,
        inputs=[audio_input, youtube_url],
        outputs=[detected_language, transcription_output]
    )

    translate_button.click(
        process_translation,
        inputs=[transcription_output, language_selector, detected_language],
        outputs=[translation_output]
    )

    subtitle_button.click(
        process_subtitle,
        inputs=[transcription_output, translation_output],
        outputs=[subtitle_status, subtitle_download]
    )

    '''transliterate_button.click(
        process_transliteration,
        inputs=[translation_output],
        outputs=[transliteration_output]
    )'''

# Add CSS for custom styling
demo.css = """
#title {
    text-align: center;
    font-size: 36px;
    font-weight: bold;
    width: 100%;
}

#label-center {
    text-align: center;
    font-size: 18px;
    font-weight: bold;
    width: 50%;
    margin: auto;
}
"""

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(share=True, debug=True, pwa=True)
create_srt.py
ADDED
@@ -0,0 +1,16 @@
def _timestamp(seconds):
    # Format whole seconds as an SRT timestamp, rolling over into minutes/hours
    # (the previous f"00:00:{idx:02d},000" broke past 59 lines)
    m, s = divmod(seconds, 60)
    h, m = divmod(m, 60)
    return f"{h:02d}:{m:02d}:{s:02d},000"

def create_srt(transcription, translated_text):
    srt_content = ""
    for idx, (trans, trans_tr) in enumerate(zip(transcription.split("\n"), translated_text.split("\n"))):
        # Naive timing: one second per line pair, not aligned to the audio
        start_time = _timestamp(idx)
        end_time = _timestamp(idx + 1)
        srt_content += f"{idx+1}\n{start_time} --> {end_time}\n{trans}\n{trans_tr}\n\n"

    # Save to a file
    subtitle_file = "translated_subtitles.srt"
    with open(subtitle_file, "w", encoding="utf-8") as f:
        f.write(srt_content)

    return subtitle_file
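A quick sanity check of create_srt (a minimal sketch; the translation string is illustrative, and the one-second-per-line timing is a placeholder, not audio alignment):

from create_srt import create_srt

transcription = "Hello world\nHow are you?"
translation = "नमस्ते दुनिया\nआप कैसे हैं?"  # illustrative translation

srt_path = create_srt(transcription, translation)
print(open(srt_path, encoding="utf-8").read())
# 1
# 00:00:00,000 --> 00:00:01,000
# Hello world
# नमस्ते दुनिया
# ...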
custom_theme.py
ADDED
@@ -0,0 +1,69 @@
from __future__ import annotations
from typing import Iterable

from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes


class CustomTheme(Base):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.rose,
        secondary_hue: colors.Color | str = colors.amber,
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Quicksand"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono: fonts.Font
        | str
        | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            # 🌅 Elegant gradient background
            body_background_fill="""
                radial-gradient(circle at top left, *primary_200, *secondary_100),
                linear-gradient(120deg, *primary_300, *secondary_200)
            """,
            body_background_fill_dark="""
                radial-gradient(circle at bottom right, *primary_800, *secondary_600),
                linear-gradient(120deg, *primary_900, *secondary_700)
            """,

            # 🔘 Dark grey buttons with hover effect
            button_primary_background_fill="#4A4A4A",
            button_primary_background_fill_hover="#6A6A6A",
            button_primary_text_color="white",
            button_primary_background_fill_dark="#3A3A3A",
            button_primary_shadow="0px 4px 12px rgba(0,0,0,0.3)",

            # 🖱️ Other UI elements
            slider_color="*secondary_300",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_large_padding="32px",
        )
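The theme plugs into any Blocks app the same way app.py uses it; a minimal standalone sketch:

import gradio as gr
from custom_theme import CustomTheme

with gr.Blocks(theme=CustomTheme()) as demo:
    gr.Markdown("## Themed demo")
    gr.Button("Primary button")  # picks up the dark-grey button styling set above

demo.launch()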
install.sh
ADDED
@@ -0,0 +1,33 @@
#!/bin/bash

echo "Updating package list..."
sudo apt update

echo "Installing required system dependencies..."
sudo apt install -y git wget ffmpeg  # ffmpeg is needed for audio extraction

echo "Upgrading pip..."
pip install --upgrade pip

echo "Installing Python dependencies..."
pip install --upgrade --no-deps --force-reinstall git+https://github.com/openai/whisper.git
pip install transformers torch numpy gitpython
pip install openai-whisper  # the PyPI package for OpenAI Whisper is "openai-whisper", not "whisper"
pip install bitsandbytes accelerate
pip install gradio
pip install yt_dlp tiktoken

echo "Cloning IndicTrans2 repository..."
if [ ! -d "IndicTrans2" ]; then
    git clone https://github.com/AI4Bharat/IndicTrans2
else
    echo "IndicTrans2 already exists, skipping clone."
fi

echo "Navigating to IndicTrans2 directory..."
cd IndicTrans2/huggingface_interface || exit

echo "Running IndicTrans2 install script..."
bash install.sh
cd ..
echo "Setup complete! You can now use Whisper, IndicTrans2, and Gradio."
process_yt_video.py
ADDED
@@ -0,0 +1,35 @@
import re
import yt_dlp

# Function to extract the video ID from a YouTube URL
def get_video_id(youtube_url):
    match = re.search(r"(?:v=|\/)([a-zA-Z0-9_-]{11})", youtube_url)
    return match.group(1) if match else None  # group(1) is the bare 11-character ID

# Function to generate a YouTube embed URL
def get_embed_url(youtube_url):
    video_id = get_video_id(youtube_url)
    if video_id:
        return f"https://www.youtube.com/embed/{video_id}"
    return None

# Function to download the audio track of a video
def download_audio(youtube_url):
    video_id = get_video_id(youtube_url)
    if not video_id:
        return None  # Invalid URL; callers expect a single value

    ydl_opts = {
        'format': 'bestaudio/best',
        'outtmpl': 'temp_audio.%(ext)s',
        'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'mp3', 'preferredquality': '192'}],
        'quiet': True,
    }

    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.extract_info(youtube_url, download=True)
    audio_path = "temp_audio.mp3"

    return audio_path
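Both common YouTube URL shapes resolve to the same 11-character ID (illustrative URLs; any valid video ID works):

from process_yt_video import get_video_id, get_embed_url

print(get_video_id("https://www.youtube.com/watch?v=dQw4w9WgXcQ"))  # dQw4w9WgXcQ
print(get_video_id("https://youtu.be/dQw4w9WgXcQ"))                 # dQw4w9WgXcQ
print(get_embed_url("https://youtu.be/dQw4w9WgXcQ"))                # https://www.youtube.com/embed/dQw4w9WgXcQ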
transcription.py
ADDED
@@ -0,0 +1,52 @@
import whisper
import numpy as np

from whisper import load_model, transcribe
from whisper.audio import load_audio

def transcribe_audio(audio):
    model = load_model("small")
    #audio_path = "/content/bharat.mp3"
    #audio = load_audio(audio_path)
    result = transcribe(model, audio)

    detected_language = result.get("language")
    # Map Whisper language codes to IndicTrans2 language tags
    whisper_to_indictrans2 = {
        "as": "asm_Beng",        # Assamese
        "bn": "ben_Beng",        # Bengali
        "brx": "brx_Deva",       # Bodo
        "doi": "doi_Deva",       # Dogri
        "gu": "guj_Gujr",        # Gujarati
        "hi": "hin_Deva",        # Hindi
        "kn": "kan_Knda",        # Kannada
        "ks": "kas_Arab",        # Kashmiri (Perso-Arabic script)
        "ks_Deva": "kas_Deva",   # Kashmiri (Devanagari script)
        "kok": "kok_Deva",       # Konkani
        "mai": "mai_Deva",       # Maithili
        "ml": "mal_Mlym",        # Malayalam
        "mni": "mni_Beng",       # Manipuri (Bengali script)
        "mni_Mtei": "mni_Mtei",  # Manipuri (Meitei script)
        "mr": "mar_Deva",        # Marathi
        "ne": "nep_Deva",        # Nepali
        "or": "ory_Orya",        # Odia
        "pa": "pan_Guru",        # Punjabi
        "sa": "san_Deva",        # Sanskrit
        "sat": "sat_Olck",       # Santali (Ol Chiki script)
        "sd": "snd_Arab",        # Sindhi (Perso-Arabic script)
        "sd_Deva": "snd_Deva",   # Sindhi (Devanagari script)
        "ta": "tam_Taml",        # Tamil
        "te": "tel_Telu",        # Telugu
        "ur": "urd_Arab",        # Urdu
        "en": "eng_Latn",        # English
    }
    if detected_language in whisper_to_indictrans2:
        detected_language = whisper_to_indictrans2[detected_language]
    else:
        return "Unknown language detected", None

    transcription = result.get("text")

    return detected_language, transcription
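Called directly, the helper returns an IndicTrans2-style language tag plus the raw transcript (a sketch; "sample.mp3" is a placeholder path):

from transcription import transcribe_audio

detected_lang, text = transcribe_audio("sample.mp3")
if text is None:
    print(detected_lang)        # "Unknown language detected"
else:
    print(detected_lang, text)  # e.g. "hin_Deva" followed by the transcript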
translation.py
ADDED
@@ -0,0 +1,219 @@
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import sys
import os

# Add the absolute path of IndicTransToolkit to the import path
indictrans_path = "/content/Voice-to-Text-Translation-System-Leveraging-Whisper-and-IndicTrans2/IndicTrans2/huggingface_interface/IndicTransToolkit/IndicTransToolkit"
sys.path.append(indictrans_path)

from processor import IndicProcessor

# Check if a GPU is available
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def translate_text(transcription, target_lang, src_lang):
    # Map display names (from the UI dropdown) to IndicTrans2 language tags
    mapping = {
        "Assamese": "asm_Beng", "Bengali": "ben_Beng", "Bodo": "brx_Deva", "Dogri": "doi_Deva",
        "Gujarati": "guj_Gujr", "Hindi": "hin_Deva", "Kannada": "kan_Knda",
        "Kashmiri(Perso-Arabic script)": "kas_Arab", "Kashmiri(Devanagari script)": "kas_Deva",
        "Konkani": "kok_Deva", "Maithili": "mai_Deva", "Malayalam": "mal_Mlym",
        "Manipuri(Bengali script)": "mni_Beng", "Manipuri(Meitei script)": "mni_Mtei",
        "Marathi": "mar_Deva", "Nepali": "nep_Deva", "Odia": "ory_Orya",
        "Punjabi": "pan_Guru", "Sanskrit": "san_Deva", "Santali(Ol Chiki script)": "sat_Olck",
        "Sindhi(Perso-Arabic script)": "snd_Arab", "Sindhi(Devanagari script)": "snd_Deva",
        "Tamil": "tam_Taml", "Telugu": "tel_Telu", "Urdu": "urd_Arab", "English": "eng_Latn",
    }
    if target_lang in mapping:
        tgt_lang = mapping[target_lang]
    else:
        return "Unsupported target language"

    if src_lang == tgt_lang:
        return "Detected Language and Target Language cannot be the same"

    if src_lang == "eng_Latn":
        model_name = "prajdabre/rotary-indictrans2-en-indic-1B"
    else:
        # Indic -> Indic translation is pivoted through English
        model1_name = "prajdabre/rotary-indictrans2-indic-en-1B"
        model2_name = "prajdabre/rotary-indictrans2-en-indic-1B"
        translations = indic_indic(model1_name, model2_name, src_lang, tgt_lang, transcription)
        return translations

    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # Load the model in half precision
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        #load_in_8bit=True,
        attn_implementation="flash_attention_2"
    ).to(DEVICE)

    ip = IndicProcessor(inference=True)

    input_sentences = [transcription]

    batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=tgt_lang)

    # Tokenize the sentences and generate input encodings
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        max_length=2048,
    )

    # Move inputs to the correct device (only inputs, NOT the model)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # Generate translations using the model
    with torch.inference_mode():
        generated_tokens = model.generate(
            **inputs,
            num_beams=5,
            length_penalty=1.5,
            repetition_penalty=2.0,
            num_return_sequences=1,
            max_new_tokens=2048,
            early_stopping=True
        )

    # Move generated tokens to CPU before decoding
    generated_tokens = generated_tokens.cpu().tolist()

    # Decode the generated tokens into text
    with tokenizer.as_target_tokenizer():
        generated_tokens = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

    # Postprocess the translations; postprocess_batch returns a list, so take the single sentence
    translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
    return translations[0]


def indic_indic(model1_name, model2_name, src_lang, tgt_lang, transcription, intermediate_lng="eng_Latn"):
    # First hop: source Indic language -> English
    tokenizer = AutoTokenizer.from_pretrained(model1_name, trust_remote_code=True)

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model1_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        #load_in_8bit=True,
        attn_implementation="flash_attention_2"
    ).to(DEVICE)

    ip = IndicProcessor(inference=True)

    input_sentences = [transcription]

    batch = ip.preprocess_batch(input_sentences, src_lang=src_lang, tgt_lang=intermediate_lng)

    # Tokenize the sentences and generate input encodings
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        max_length=2048,
    )

    # Move inputs to the correct device (only inputs, NOT the model)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # Generate translations using the model
    with torch.inference_mode():
        generated_tokens = model.generate(
            **inputs,
            num_beams=10,
            length_penalty=1.5,
            repetition_penalty=2.0,
            num_return_sequences=1,
            max_new_tokens=2048,
            early_stopping=True
        )

    # Move generated tokens to CPU before decoding
    generated_tokens = generated_tokens.cpu().tolist()

    # Decode the generated tokens into text
    with tokenizer.as_target_tokenizer():
        generated_tokens = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

    # Postprocess the intermediate translation; the first hop's output is English,
    # so postprocess with the intermediate language tag, not the final target's
    translations1 = ip.postprocess_batch(generated_tokens, lang=intermediate_lng)
    translations1 = translations1[0]

    # Second hop: English -> target Indic language
    tokenizer = AutoTokenizer.from_pretrained(model2_name, trust_remote_code=True)

    model = AutoModelForSeq2SeqLM.from_pretrained(
        model2_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        #load_in_8bit=True,
        attn_implementation="flash_attention_2"
    ).to(DEVICE)

    ip = IndicProcessor(inference=True)

    input_sentences = [translations1]

    batch = ip.preprocess_batch(input_sentences, src_lang=intermediate_lng, tgt_lang=tgt_lang)

    # Tokenize the sentences and generate input encodings
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        max_length=2048,
    )

    # Move inputs to the correct device (only inputs, NOT the model)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    # Generate translations using the model
    with torch.inference_mode():
        generated_tokens = model.generate(
            **inputs,
            num_beams=10,
            length_penalty=1.5,
            repetition_penalty=2.0,
            num_return_sequences=1,
            max_new_tokens=2048,
            early_stopping=True
        )

    # Move generated tokens to CPU before decoding
    generated_tokens = generated_tokens.cpu().tolist()

    # Decode the generated tokens into text
    with tokenizer.as_target_tokenizer():
        generated_tokens = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

    # Postprocess the final translations and return the single sentence
    translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
    return translations[0]
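Chaining the two modules gives the full pipeline (a sketch; assumes the IndicTransToolkit path above exists and the hardware supports float16 with flash attention):

from transcription import transcribe_audio
from translation import translate_text

detected_lang, text = transcribe_audio("sample.mp3")  # placeholder path
if text:
    # English sources translate directly; Indic sources pivot through English
    print(translate_text(text, "Hindi", detected_lang))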
transliteration.py
ADDED
@@ -0,0 +1 @@
# Under development