from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torchaudio
import torch
# Load the processor and the Arabic model
processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
def speech_to_text(audio_path):
    if audio_path is None:
        raise ValueError("Audio file is missing")

    # Load the audio file
    waveform, sample_rate = torchaudio.load(audio_path)

    # If the audio is stereo, downmix it to mono
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample to 16 kHz if the rate differs (the model expects 16 kHz input)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
        waveform = resampler(waveform)

    # Prepare the input for the model
    input_values = processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000).input_values

    # Run the model and take the most likely token at each time step
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)

    # Decode the predicted token ids into text
    transcription = processor.batch_decode(predicted_ids)
    return transcription[0]
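

# Minimal usage sketch (an assumption, not part of the original file):
# "sample.wav" below is a hypothetical local audio path used only for
# illustration; any mono or stereo file torchaudio can read should work.
if __name__ == "__main__":
    text = speech_to_text("sample.wav")
    print(text)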