Lingalingeswaran commited on
Commit
85dbb76
·
verified ·
1 Parent(s): 60fd575

Upload pipeline.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. pipeline.py +45 -0
pipeline.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import whisper
3
+ from transformers import WhisperProcessor, WhisperForConditionalGeneration
4
+ from googletrans import Translator
5
+ import torch
6
+
7
+ def load_models():
8
+ lang_detector = whisper.load_model("small")
9
+ tamil_processor = WhisperProcessor.from_pretrained("Lingalingeswaran/whisper-small-ta")
10
+ tamil_model = WhisperForConditionalGeneration.from_pretrained("Lingalingeswaran/whisper-small-ta")
11
+ sinhala_processor = WhisperProcessor.from_pretrained("Lingalingeswaran/whisper-small-sinhala")
12
+ sinhala_model = WhisperForConditionalGeneration.from_pretrained("Lingalingeswaran/whisper-small-sinhala")
13
+ english_model = whisper.load_model("small")
14
+ return lang_detector, tamil_processor, tamil_model, sinhala_processor, sinhala_model, english_model
15
+
16
+ def detect_language(audio_file, lang_detector):
17
+ audio = whisper.load_audio(audio_file)
18
+ audio = whisper.pad_or_trim(audio)
19
+ mel = whisper.log_mel_spectrogram(audio).to(lang_detector.device)
20
+ _, probs = lang_detector.detect_language(mel)
21
+ return max(probs, key=probs.get)
22
+
23
+ def transcribe_audio(audio_file, detected_lang, tamil_processor, tamil_model, sinhala_processor, sinhala_model, english_model):
24
+ if detected_lang == "ta":
25
+ processor, model = tamil_processor, tamil_model
26
+ elif detected_lang == "si":
27
+ processor, model = sinhala_processor, sinhala_model
28
+ else:
29
+ model = english_model
30
+ return model.transcribe(audio_file)["text"]
31
+
32
+ audio = whisper.load_audio(audio_file)
33
+ inputs = processor(audio, return_tensors="pt", sampling_rate=16000)
34
+ with torch.no_grad():
35
+ predicted_ids = model.generate(**inputs)
36
+ return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
37
+
38
+ def translate_to_english(text):
39
+ return Translator().translate(text, dest="en").text
40
+
41
+ def full_pipeline(audio_file):
42
+ lang_detector, tamil_processor, tamil_model, sinhala_processor, sinhala_model, english_model = load_models()
43
+ detected_lang = detect_language(audio_file, lang_detector)
44
+ transcription = transcribe_audio(audio_file, detected_lang, tamil_processor, tamil_model, sinhala_processor, sinhala_model, english_model)
45
+ return translate_to_english(transcription)