injilashah committed
Commit 4c6879b · verified · 1 Parent(s): b111a0c

Upload 8 files

Files changed (8)
  1. app.py +144 -0
  2. create_srt.py +16 -0
  3. custom_theme.py +69 -0
  4. install.sh +33 -0
  5. process_yt_video.py +35 -0
  6. transcription.py +52 -0
  7. translation.py +219 -0
  8. transliteration.py +1 -0
app.py ADDED
@@ -0,0 +1,144 @@
+ import gradio as gr
+ from transcription import transcribe_audio
+ from translation import translate_text
+ from process_yt_video import download_audio, get_embed_url
+ from create_srt import create_srt
+ from custom_theme import CustomTheme
+ 
+ stheme = CustomTheme()
+ 
+ # Handle transcription (from an uploaded/recorded file or a YouTube URL)
+ def process_transcription(audio_file, youtube_url):
+     if youtube_url:
+         audio_file = download_audio(youtube_url)
+         print(f"Downloaded audio file from YouTube: {audio_file}")
+ 
+     if not audio_file:
+         # This callback is wired to two outputs, so return exactly two values
+         return "No audio provided!", None
+ 
+     print(f"Processing audio file: {audio_file}")
+     detected_lang, transcription = transcribe_audio(audio_file)
+ 
+     if not transcription:
+         return "Error in transcription", None
+ 
+     return detected_lang, transcription
+ 
+ # Handle translation
+ def process_translation(transcription, target_lang, detected_lang):
+     if not transcription:
+         return "Please transcribe first!"
+     return translate_text(transcription, target_lang, detected_lang)
+ 
+ # Handle subtitle generation
+ def process_subtitle(transcription, translation):
+     if not transcription or not translation:
+         return "Please transcribe and translate first!", None
+     subtitle_file = create_srt(transcription, translation)
+     return "Subtitle generated successfully!", subtitle_file
+ 
+ # Transliteration handler (disabled; transliteration.py is under development)
+ # def process_transliteration(translated_text):
+ #     if not translated_text:
+ #         return "Please translate first!"
+ 
+ # Update the embedded YouTube video player
+ def update_video(youtube_url):
+     embed_url = get_embed_url(youtube_url)
+     return f"<iframe width='560' height='315' src='{embed_url}' frameborder='0' allowfullscreen></iframe>" if embed_url else ""
+ 
+ # Custom CSS, passed to gr.Blocks below (assigning demo.css after construction
+ # is not the documented way to apply styles)
+ custom_css = """
+ #title {
+     text-align: center;
+     font-size: 36px;
+     font-weight: bold;
+     width: 100%;
+ }
+ 
+ #label-center {
+     text-align: center;
+     font-size: 18px;
+     font-weight: bold;
+     width: 50%;
+     margin: auto;
+ }
+ """
+ 
+ with gr.Blocks(theme=stheme, css=custom_css) as demo:
+     gr.Markdown("# Voice-to-Text Translation System", elem_id="title")
+ 
+     with gr.Row():
+         with gr.Column():
+             gr.Markdown("## Upload Audio or Enter a YouTube URL")
+             audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Record or Upload Audio 🎤", min_width=50)
+             youtube_url = gr.Textbox(label="Enter YouTube Link", min_width=50)
+             video_player = gr.HTML("")
+             youtube_url.change(update_video, inputs=[youtube_url], outputs=[video_player])
+ 
+         with gr.Column():
+             with gr.Row():
+                 gr.Markdown("")  # spacer
+                 transcribe_button = gr.Button("Generate Transcription", interactive=True, size="sm", min_width=800)
+             detected_language = gr.Textbox(label="Detected Language", interactive=False, min_width=400)
+             transcription_output = gr.Textbox(label="Transcription", interactive=False, min_width=400)
+ 
+             with gr.Row():
+                 language_selector = gr.Dropdown([
+                     'Assamese', 'Bengali', 'Bodo', 'Dogri', 'English', 'Gujarati', 'Hindi', 'Kannada', 'Kashmiri(Perso-Arabic script)',
+                     'Kashmiri(Devanagari script)', 'Konkani', 'Maithili', 'Malayalam', 'Manipuri(Bengali script)',
+                     'Manipuri(Meitei script)', 'Marathi', 'Nepali', 'Odia', 'Punjabi', 'Sanskrit', 'Santali(Ol Chiki script)',
+                     'Sindhi(Perso-Arabic script)', 'Sindhi(Devanagari script)', 'Tamil', 'Telugu', 'Urdu'
+                 ], label="Select Target Language", min_width=400)
+             with gr.Row():
+                 translate_button = gr.Button("Generate Translation", interactive=True, size="sm", min_width=350)
+             with gr.Row():
+                 translation_output = gr.Textbox(label="Translation", interactive=False, min_width=400)
+ 
+             with gr.Row():
+                 subtitle_button = gr.Button("Generate Subtitles", interactive=True, size="sm", min_width=350)
+             with gr.Row():
+                 subtitle_status = gr.Textbox(label="Subtitle Status", interactive=False, min_width=400)
+                 subtitle_download = gr.File(label="Download Subtitles", visible=True, min_width=400)
+ 
+     # Transliteration UI (disabled until transliteration.py is implemented)
+     # with gr.Column():
+     #     transliterate_button = gr.Button("Generate Transliteration", interactive=True)
+     #     transliteration_output = gr.Textbox(label="Transliteration", interactive=False)
+ 
+     transcribe_button.click(
+         process_transcription,
+         inputs=[audio_input, youtube_url],
+         outputs=[detected_language, transcription_output]
+     )
+ 
+     translate_button.click(
+         process_translation,
+         inputs=[transcription_output, language_selector, detected_language],
+         outputs=[translation_output]
+     )
+ 
+     subtitle_button.click(
+         process_subtitle,
+         inputs=[transcription_output, translation_output],
+         outputs=[subtitle_status, subtitle_download]
+     )
+ 
+     # transliterate_button.click(
+     #     process_transliteration,
+     #     inputs=[translation_output],
+     #     outputs=[transliteration_output]
+     # )
+ 
+ # Launch the Gradio app
+ if __name__ == "__main__":
+     demo.launch(share=True, debug=True, pwa=True)
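
For reference, the same pipeline can be exercised without the UI. A minimal sketch, assuming the modules above are importable and a local sample.mp3 exists (both the file name and the Hindi target are illustrative):

    from transcription import transcribe_audio
    from translation import translate_text
    from create_srt import create_srt

    detected_lang, text = transcribe_audio("sample.mp3")  # hypothetical input file
    translated = translate_text(text, "Hindi", detected_lang)
    print(translated)
    print(create_srt(text, translated))  # path of the generated .srt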
create_srt.py ADDED
@@ -0,0 +1,16 @@
+ def create_srt(transcription, translated_text):
+     srt_content = ""
+     # Placeholder timing: each line gets a one-second cue. Note that
+     # f"00:00:{idx:02d}" only yields valid SRT timestamps for the first 60 lines.
+     for idx, (trans, trans_tr) in enumerate(zip(transcription.split("\n"), translated_text.split("\n"))):
+         start_time = f"00:00:{idx:02d},000"
+         end_time = f"00:00:{(idx + 1):02d},000"
+         srt_content += f"{idx + 1}\n{start_time} --> {end_time}\n{trans}\n{trans_tr}\n\n"
+ 
+     # Save to a file
+     subtitle_file = "translated_subtitles.srt"
+     with open(subtitle_file, "w", encoding="utf-8") as f:
+         f.write(srt_content)
+ 
+     return subtitle_file
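
For a single transcription/translation pair, the generated file would contain one cue of the form (text illustrative):

    1
    00:00:00,000 --> 00:00:01,000
    hello world
    नमस्ते दुनिया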
custom_theme.py ADDED
@@ -0,0 +1,69 @@
+ from __future__ import annotations
+ from typing import Iterable
+ 
+ from gradio.themes.base import Base
+ from gradio.themes.utils import colors, fonts, sizes
+ 
+ 
+ class CustomTheme(Base):
+     def __init__(
+         self,
+         *,
+         primary_hue: colors.Color | str = colors.rose,
+         secondary_hue: colors.Color | str = colors.amber,
+         neutral_hue: colors.Color | str = colors.gray,
+         spacing_size: sizes.Size | str = sizes.spacing_md,
+         radius_size: sizes.Size | str = sizes.radius_md,
+         text_size: sizes.Size | str = sizes.text_lg,
+         font: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("Quicksand"),
+             "ui-sans-serif",
+             "sans-serif",
+         ),
+         font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+             fonts.GoogleFont("IBM Plex Mono"),
+             "ui-monospace",
+             "monospace",
+         ),
+     ):
+         super().__init__(
+             primary_hue=primary_hue,
+             secondary_hue=secondary_hue,
+             neutral_hue=neutral_hue,
+             spacing_size=spacing_size,
+             radius_size=radius_size,
+             text_size=text_size,
+             font=font,
+             font_mono=font_mono,
+         )
+         super().set(
+             # 🌅 Elegant gradient backgrounds (light and dark)
+             body_background_fill="""
+                 radial-gradient(circle at top left, *primary_200, *secondary_100),
+                 linear-gradient(120deg, *primary_300, *secondary_200)
+             """,
+             body_background_fill_dark="""
+                 radial-gradient(circle at bottom right, *primary_800, *secondary_600),
+                 linear-gradient(120deg, *primary_900, *secondary_700)
+             """,
+ 
+             # 🔘 Dark grey buttons with a hover effect
+             button_primary_background_fill="#4A4A4A",
+             button_primary_background_fill_hover="#6A6A6A",
+             button_primary_text_color="white",
+             button_primary_background_fill_dark="#3A3A3A",
+             button_primary_shadow="0px 4px 12px rgba(0,0,0,0.3)",
+ 
+             # 🖱️ Other UI elements
+             slider_color="*secondary_300",
+             slider_color_dark="*secondary_600",
+             block_title_text_weight="600",
+             block_border_width="3px",
+             block_shadow="*shadow_drop_lg",
+             button_large_padding="32px",
+         )
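
A minimal standalone usage sketch of the theme (the button label is illustrative):

    import gradio as gr
    from custom_theme import CustomTheme

    with gr.Blocks(theme=CustomTheme()) as demo:
        gr.Button("Styled button")  # picks up the dark-grey primary button fills set above
    demo.launch()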
install.sh ADDED
@@ -0,0 +1,33 @@
+ #!/bin/bash
+ 
+ echo "Updating package list..."
+ sudo apt update
+ 
+ echo "Installing required system dependencies..."
+ sudo apt install -y git wget ffmpeg  # ffmpeg is needed for audio extraction
+ 
+ echo "Upgrading pip..."
+ pip install --upgrade pip
+ 
+ echo "Installing Python dependencies..."
+ pip install --upgrade --no-deps --force-reinstall git+https://github.com/openai/whisper.git
+ pip install transformers torch numpy gitpython
+ pip install openai-whisper  # PyPI name of OpenAI Whisper (plain "whisper" is an unrelated package)
+ pip install bitsandbytes accelerate
+ pip install gradio
+ pip install yt_dlp tiktoken
+ 
+ echo "Cloning IndicTrans2 repository..."
+ if [ ! -d "IndicTrans2" ]; then
+     git clone https://github.com/AI4Bharat/IndicTrans2
+ else
+     echo "IndicTrans2 already exists, skipping clone."
+ fi
+ 
+ echo "Navigating to IndicTrans2 directory..."
+ cd IndicTrans2/huggingface_interface || exit
+ 
+ echo "Running IndicTrans2 install script..."
+ bash install.sh
+ cd ../..  # return to the project root ("cd.." was missing the space and only went up one level)
+ echo "Setup complete! You can now use Whisper, IndicTrans2, and Gradio."
process_yt_video.py ADDED
@@ -0,0 +1,35 @@
+ import re
+ import yt_dlp
+ 
+ # Extract the 11-character video ID from a YouTube URL
+ def get_video_id(youtube_url):
+     match = re.search(r"(?:v=|\/)([a-zA-Z0-9_-]{11})", youtube_url)
+     # group(1) is the bare ID; group(0) would include the "v=" or "/" prefix
+     return match.group(1) if match else None
+ 
+ # Build the YouTube embed URL for the in-page player
+ def get_embed_url(youtube_url):
+     video_id = get_video_id(youtube_url)
+     if video_id:
+         return f"https://www.youtube.com/embed/{video_id}"
+     return None
+ 
+ # Download the audio track as MP3 and return its path
+ def download_audio(youtube_url):
+     video_id = get_video_id(youtube_url)
+     if not video_id:
+         return None  # invalid URL; callers expect a single value
+ 
+     ydl_opts = {
+         'format': 'bestaudio/best',
+         'outtmpl': 'temp_audio.%(ext)s',
+         'postprocessors': [{'key': 'FFmpegExtractAudio', 'preferredcodec': 'mp3', 'preferredquality': '192'}],
+         'quiet': True,
+     }
+ 
+     with yt_dlp.YoutubeDL(ydl_opts) as ydl:
+         ydl.extract_info(youtube_url, download=True)
+ 
+     return "temp_audio.mp3"
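
A quick sanity check of the URL helpers (video ID illustrative):

    >>> get_video_id("https://www.youtube.com/watch?v=dQw4w9WgXcQ")
    'dQw4w9WgXcQ'
    >>> get_embed_url("https://youtu.be/dQw4w9WgXcQ")
    'https://www.youtube.com/embed/dQw4w9WgXcQ'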
transcription.py ADDED
@@ -0,0 +1,52 @@
+ from whisper import load_model, transcribe
+ 
+ # Map Whisper language codes to IndicTrans2 language tags
+ WHISPER_TO_INDICTRANS2 = {
+     "as": "asm_Beng",        # Assamese
+     "bn": "ben_Beng",        # Bengali
+     "brx": "brx_Deva",       # Bodo
+     "doi": "doi_Deva",       # Dogri
+     "gu": "guj_Gujr",        # Gujarati
+     "hi": "hin_Deva",        # Hindi
+     "kn": "kan_Knda",        # Kannada
+     "ks": "kas_Arab",        # Kashmiri (Perso-Arabic script)
+     "ks_Deva": "kas_Deva",   # Kashmiri (Devanagari script)
+     "kok": "kok_Deva",       # Konkani
+     "mai": "mai_Deva",       # Maithili
+     "ml": "mal_Mlym",        # Malayalam
+     "mni": "mni_Beng",       # Manipuri (Bengali script)
+     "mni_Mtei": "mni_Mtei",  # Manipuri (Meitei script)
+     "mr": "mar_Deva",        # Marathi
+     "ne": "nep_Deva",        # Nepali
+     "or": "ory_Orya",        # Odia
+     "pa": "pan_Guru",        # Punjabi
+     "sa": "san_Deva",        # Sanskrit
+     "sat": "sat_Olck",       # Santali (Ol Chiki script)
+     "sd": "snd_Arab",        # Sindhi (Perso-Arabic script)
+     "sd_Deva": "snd_Deva",   # Sindhi (Devanagari script)
+     "ta": "tam_Taml",        # Tamil
+     "te": "tel_Telu",        # Telugu
+     "ur": "urd_Arab",        # Urdu
+     "en": "eng_Latn",        # English
+ }
+ 
+ def transcribe_audio(audio):
+     # Note: the model is reloaded on every call; cache it if latency matters
+     model = load_model("small")
+     result = transcribe(model, audio)
+ 
+     detected_language = result.get("language")
+     if detected_language not in WHISPER_TO_INDICTRANS2:
+         return "Unknown language detected", None
+ 
+     detected_language = WHISPER_TO_INDICTRANS2[detected_language]
+     transcription = result.get("text")
+     return detected_language, transcription
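
The return contract is a (language_tag, text) pair; for example (file name illustrative):

    lang_tag, text = transcribe_audio("sample.mp3")
    print(lang_tag)  # e.g. "hin_Deva" when Whisper detects Hindi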
translation.py ADDED
@@ -0,0 +1,219 @@
+ import sys
+ 
+ import torch
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ 
+ # Absolute path of IndicTransToolkit (Colab checkout layout)
+ indictrans_path = "/content/Voice-to-Text-Translation-System-Leveraging-Whisper-and-IndicTrans2/IndicTrans2/huggingface_interface/IndicTransToolkit/IndicTransToolkit"
+ sys.path.append(indictrans_path)
+ 
+ from processor import IndicProcessor
+ 
+ # Use the GPU when available
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ 
+ # Display names from the UI dropdown -> IndicTrans2 language tags
+ LANG_MAPPING = {
+     "Assamese": "asm_Beng", "Bengali": "ben_Beng", "Bodo": "brx_Deva", "Dogri": "doi_Deva",
+     "Gujarati": "guj_Gujr", "Hindi": "hin_Deva", "Kannada": "kan_Knda",
+     "Kashmiri(Perso-Arabic script)": "kas_Arab", "Kashmiri(Devanagari script)": "kas_Deva",
+     "Konkani": "kok_Deva", "Maithili": "mai_Deva", "Malayalam": "mal_Mlym",
+     "Manipuri(Bengali script)": "mni_Beng", "Manipuri(Meitei script)": "mni_Mtei",
+     "Marathi": "mar_Deva", "Nepali": "nep_Deva", "Odia": "ory_Orya",
+     "Punjabi": "pan_Guru", "Sanskrit": "san_Deva", "Santali(Ol Chiki script)": "sat_Olck",
+     "Sindhi(Perso-Arabic script)": "snd_Arab", "Sindhi(Devanagari script)": "snd_Deva",
+     "Tamil": "tam_Taml", "Telugu": "tel_Telu", "Urdu": "urd_Arab", "English": "eng_Latn",
+ }
+ 
+ 
+ def run_model(model_name, src_lang, tgt_lang, sentences, num_beams=5):
+     # Load the tokenizer and model for one translation leg
+     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     model = AutoModelForSeq2SeqLM.from_pretrained(
+         model_name,
+         trust_remote_code=True,
+         torch_dtype=torch.float16,
+         # load_in_8bit=True,
+         attn_implementation="flash_attention_2",
+     ).to(DEVICE)
+ 
+     ip = IndicProcessor(inference=True)
+     batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=tgt_lang)
+ 
+     # Tokenize the sentences and generate input encodings
+     inputs = tokenizer(
+         batch,
+         truncation=True,
+         padding="longest",
+         return_tensors="pt",
+         max_length=2048,
+     )
+     # Move inputs to the correct device (only inputs, NOT the model)
+     inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
+ 
+     # Generate translations using the model
+     with torch.inference_mode():
+         generated_tokens = model.generate(
+             **inputs,
+             num_beams=num_beams,
+             length_penalty=1.5,
+             repetition_penalty=2.0,
+             num_return_sequences=1,
+             max_new_tokens=2048,
+             early_stopping=True,
+         )
+ 
+     # Move generated tokens to CPU before decoding
+     generated_tokens = generated_tokens.cpu().tolist()
+ 
+     # Decode the generated tokens into text
+     with tokenizer.as_target_tokenizer():
+         decoded = tokenizer.batch_decode(
+             generated_tokens,
+             skip_special_tokens=True,
+             clean_up_tokenization_spaces=True,
+         )
+ 
+     # Postprocess the translations for the language of THIS leg's target
+     return ip.postprocess_batch(decoded, lang=tgt_lang)
+ 
+ 
+ def translate_text(transcription, target_lang, src_lang):
+     if target_lang not in LANG_MAPPING:
+         return "Unsupported target language"
+     tgt_lang = LANG_MAPPING[target_lang]
+ 
+     if src_lang == tgt_lang:
+         return "Detected Language and Target Language cannot be the same"
+ 
+     if src_lang == "eng_Latn":
+         # Direct English -> Indic leg
+         translations = run_model("prajdabre/rotary-indictrans2-en-indic-1B",
+                                  src_lang, tgt_lang, [transcription])
+     elif tgt_lang == "eng_Latn":
+         # Direct Indic -> English leg
+         translations = run_model("prajdabre/rotary-indictrans2-indic-en-1B",
+                                  src_lang, tgt_lang, [transcription])
+     else:
+         # Indic -> Indic: pivot through English using both models
+         translations = indic_indic("prajdabre/rotary-indictrans2-indic-en-1B",
+                                    "prajdabre/rotary-indictrans2-en-indic-1B",
+                                    src_lang, tgt_lang, transcription)
+     return translations[0] if translations else ""
+ 
+ 
+ def indic_indic(model1_name, model2_name, src_lang, tgt_lang, transcription, intermediate_lng="eng_Latn"):
+     # First leg: source Indic language -> intermediate English
+     english = run_model(model1_name, src_lang, intermediate_lng, [transcription], num_beams=10)
+     # Second leg: intermediate English -> target Indic language
+     return run_model(model2_name, intermediate_lng, tgt_lang, english, num_beams=10)
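
A usage sketch (the sentence is illustrative; a GPU plus flash-attn is assumed by the model-loading options above):

    # Hindi -> Tamil pivots through English via indic_indic
    print(translate_text("भारत एक महान देश है।", "Tamil", "hin_Deva"))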
transliteration.py ADDED
@@ -0,0 +1 @@
+ # Under development