import whisper
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from googletrans import Translator
import torch

def load_models():
    """Load the language detector plus one ASR model per supported language."""
    # Base multilingual Whisper "small" checkpoint, used for language detection.
    lang_detector = whisper.load_model("small")
    # Fine-tuned Whisper checkpoints for Tamil and Sinhala from the Hugging Face Hub.
    tamil_processor = WhisperProcessor.from_pretrained("Lingalingeswaran/whisper-small-ta")
    tamil_model = WhisperForConditionalGeneration.from_pretrained("Lingalingeswaran/whisper-small-ta")
    sinhala_processor = WhisperProcessor.from_pretrained("Lingalingeswaran/whisper-small-sinhala")
    sinhala_model = WhisperForConditionalGeneration.from_pretrained("Lingalingeswaran/whisper-small-sinhala")
    # The same base "small" checkpoint handles English transcription.
    english_model = whisper.load_model("small")
    return lang_detector, tamil_processor, tamil_model, sinhala_processor, sinhala_model, english_model

def detect_language(audio_file, lang_detector):
    """Return the most probable language code (e.g. "ta", "si", "en") for the audio."""
    audio = whisper.load_audio(audio_file)   # decode and resample to 16 kHz mono
    audio = whisper.pad_or_trim(audio)       # pad or trim to Whisper's 30 s window
    mel = whisper.log_mel_spectrogram(audio).to(lang_detector.device)
    _, probs = lang_detector.detect_language(mel)
    return max(probs, key=probs.get)

def transcribe_audio(audio_file, detected_lang, tamil_processor, tamil_model,
                     sinhala_processor, sinhala_model, english_model):
    """Transcribe the audio with the model that matches the detected language."""
    if detected_lang == "ta":
        processor, model = tamil_processor, tamil_model
    elif detected_lang == "si":
        processor, model = sinhala_processor, sinhala_model
    else:
        # English (and any other language) falls back to the base Whisper model.
        return english_model.transcribe(audio_file)["text"]

    # Hugging Face Whisper models expect 16 kHz audio converted to log-mel features.
    audio = whisper.load_audio(audio_file)
    inputs = processor(audio, return_tensors="pt", sampling_rate=16000)
    with torch.no_grad():
        predicted_ids = model.generate(**inputs)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

def translate_to_english(text):
    """Translate a transcription into English with googletrans."""
    return Translator().translate(text, dest="en").text

def full_pipeline(audio_file):
    """Detect the language, transcribe the audio, and return the English translation."""
    # Note: models are reloaded on every call; cache the result of load_models()
    # if the pipeline is run over many files.
    lang_detector, tamil_processor, tamil_model, sinhala_processor, sinhala_model, english_model = load_models()
    detected_lang = detect_language(audio_file, lang_detector)
    transcription = transcribe_audio(audio_file, detected_lang, tamil_processor, tamil_model,
                                     sinhala_processor, sinhala_model, english_model)
    return translate_to_english(transcription)
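
# Example usage (a minimal sketch, not part of the original module): the audio
# path below is a hypothetical placeholder. Any file readable by ffmpeg should
# work, since whisper.load_audio decodes and resamples it to 16 kHz.
if __name__ == "__main__":
    english_text = full_pipeline("sample_audio.wav")  # hypothetical example file
    print(english_text)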