RajatMalviya committed
Commit 40b3e9c · verified · 1 parent: 58affcb

Update app.py

Files changed (1):
  1. app.py (+24 -11)
app.py CHANGED
@@ -1,16 +1,19 @@
 import streamlit as st
 import tempfile
 import os
-from transformers import pipeline
+import librosa  # For audio resampling
+import torch
+from transformers import WhisperProcessor, WhisperForConditionalGeneration
 
-# Load the ASR model
+# Load the model and processor
 @st.cache_resource
 def load_model():
-    return pipeline("automatic-speech-recognition", model="ivrit-ai/whisper-large-v3-turbo")
+    processor = WhisperProcessor.from_pretrained("ivrit-ai/whisper-large-v3-turbo")
+    model = WhisperForConditionalGeneration.from_pretrained("ivrit-ai/whisper-large-v3-turbo")
+    return processor, model
 
-model = load_model()
+processor, model = load_model()
 
-# Streamlit UI
 st.title("Hebrew Speech-to-Text Transcription")
 
 # Upload audio file
@@ -22,14 +25,24 @@ if uploaded_file is not None:
         temp_audio.write(uploaded_file.read())
         temp_audio_path = temp_audio.name
 
-    # Transcribe the audio
-    st.write("Transcribing...")
     try:
-        result = model(temp_audio_path)
+        # Load and resample audio to 16kHz (required by Whisper)
+        speech_array, sampling_rate = librosa.load(temp_audio_path, sr=16000)
+
+        # Preprocess audio
+        inputs = processor(speech_array, sampling_rate=16000, return_tensors="pt")
+
+        # Generate transcription
+        with torch.no_grad():
+            predicted_ids = model.generate(inputs.input_features)
+
+        transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
+
         st.subheader("Transcription:")
-        st.write(result["text"])
+        st.write(transcription)
+
     except Exception as e:
         st.error(f"Error: {str(e)}")
-
+
     # Clean up the temporary file
-    os.remove(temp_audio_path)
+    os.remove(temp_audio_path)
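
For reference, the transcription path this commit introduces can be exercised outside Streamlit. The following is a minimal sketch that mirrors the logic in the diff; "audio.mp3" is a placeholder path, and the model is assumed to stay on CPU (the app never moves it to a GPU).

import librosa
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

# Model id and preprocessing steps taken verbatim from the diff above
processor = WhisperProcessor.from_pretrained("ivrit-ai/whisper-large-v3-turbo")
model = WhisperForConditionalGeneration.from_pretrained("ivrit-ai/whisper-large-v3-turbo")

# Whisper expects 16 kHz mono input, so resample on load
speech_array, _ = librosa.load("audio.mp3", sr=16000)  # placeholder file

inputs = processor(speech_array, sampling_rate=16000, return_tensors="pt")
with torch.no_grad():
    predicted_ids = model.generate(inputs.input_features)

print(processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])

Swapping the pipeline call for an explicit processor/model pair makes the 16 kHz resampling step visible in the app code instead of being handled inside the pipeline. It also makes librosa and torch direct runtime dependencies, so if the Space pins packages in a requirements.txt (not part of this commit), it would need matching entries alongside streamlit and transformers.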