wedyanessam commited on
Commit
ef2ca90
·
verified ·
1 Parent(s): 3ec929e

Update STT/sst.py

Browse files
Files changed (1) hide show
  1. STT/sst.py +15 -16
STT/sst.py CHANGED
@@ -2,30 +2,30 @@ from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
2
  import torchaudio
3
  import torch
4
 
5
- # تحميل المعالج والنموذج
6
- processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
7
- model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
8
 
9
  def speech_to_text(audio_path):
10
  if audio_path is None:
11
- raise ValueError("Audio path is None. Did you upload a file?")
12
-
13
- # تحميل الصوت
14
- waveform, sampling_rate = torchaudio.load(audio_path)
15
 
16
- # إذا كان ستيريو نخليه mono
 
 
 
17
  if waveform.shape[0] > 1:
18
- waveform = waveform.mean(dim=0)
19
 
20
- # إعادة تشكيل الصوت إذا كان غير 16kHz
21
- if sampling_rate != 16000:
22
- resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)
23
  waveform = resampler(waveform)
24
-
25
- # تجهيز البيانات للنموذج
26
  input_values = processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000).input_values
27
 
28
- # استنتاج الـ logits والتنبؤ
29
  with torch.no_grad():
30
  logits = model(input_values).logits
31
 
@@ -35,4 +35,3 @@ def speech_to_text(audio_path):
35
  transcription = processor.batch_decode(predicted_ids)
36
 
37
  return transcription[0]
38
-
 
2
  import torchaudio
3
  import torch
4
 
5
+ # تحميل المعالج والموديل العربي
6
+ processor = Wav2Vec2Processor.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
7
+ model = Wav2Vec2ForCTC.from_pretrained("jonatasgrosman/wav2vec2-large-xlsr-53-arabic")
8
 
9
  def speech_to_text(audio_path):
10
  if audio_path is None:
11
+ raise ValueError("الصوت غير موجود")
 
 
 
12
 
13
+ # تحميل الملف الصوتي
14
+ waveform, sample_rate = torchaudio.load(audio_path)
15
+
16
+ # إذا الصوت ستيريو نحوله لمونو
17
  if waveform.shape[0] > 1:
18
+ waveform = waveform.mean(dim=0).unsqueeze(0)
19
 
20
+ # إعادة تحويل التردد إلى 16000 لو كان مختلف
21
+ if sample_rate != 16000:
22
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
23
  waveform = resampler(waveform)
24
+
25
+ # تجهيز الإدخال للنموذج
26
  input_values = processor(waveform.squeeze().numpy(), return_tensors="pt", sampling_rate=16000).input_values
27
 
28
+ # تمرير البيانات للنموذج والحصول على النتائج
29
  with torch.no_grad():
30
  logits = model(input_values).logits
31
 
 
35
  transcription = processor.batch_decode(predicted_ids)
36
 
37
  return transcription[0]