chinmayc3 committed
Commit 25e924b · 1 parent: 48732e0

Added ability to upload audio and select model options

Files changed (5):
  1. app.py +132 -6
  2. enums.py +3 -0
  3. logger.py +3 -0
  4. requirements.txt +2 -1
  5. utils.py +3 -0
app.py CHANGED
@@ -15,15 +15,18 @@ import requests
 import streamlit as st
 from audio_recorder_streamlit import audio_recorder
 import torchaudio
+from dotenv import load_dotenv
 
 from logger import logger
 from utils import fs
 from enums import SAVE_PATH, ELO_JSON_PATH, ELO_CSV_PATH, EMAIL_PATH, TEMP_DIR, NEW_TASK_URL,ARENA_PATH
 
+load_dotenv()
 result_queue = Queue()
 random_df = pd.read_csv("random_audios.csv")
 random_paths = random_df["path"].tolist()
 
+
 def result_writer_thread():
     result_writer = ResultWriter(SAVE_PATH)
     while True:
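This commit threads `python-dotenv` through every module that reads configuration (`app.py` here, plus `enums.py`, `logger.py`, and `utils.py` below), so a local `.env` file can supply the same variables the deployment reads from the environment. A minimal sketch of the pattern; the variable name and value are placeholders:

```python
# .env (placeholder contents):
#   AWS_BUCKET_NAME=my-test-bucket
import os
from dotenv import load_dotenv

load_dotenv()  # reads .env from the working directory into os.environ (existing env vars win)
print(os.getenv("AWS_BUCKET_NAME"))  # -> "my-test-bucket"
```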
 
@@ -140,9 +143,9 @@ def send_task(payload):
         "Authorization": f"Bearer {os.getenv('CREATE_TASK_API_KEY')}"
     }
     if payload["task"] in ["fetch_audio","write_result"]:
-        response = requests.post(NEW_TASK_URL,json=payload,headers=header,timeout=300)
+        response = requests.post(NEW_TASK_URL,json=payload,headers=header,timeout=600)
     else:
-        response = requests.post(NEW_TASK_URL,json=payload,headers=header,timeout=300,stream=True)
+        response = requests.post(NEW_TASK_URL,json=payload,headers=header,timeout=600,stream=True)
     try:
         response = response.json()
     except Exception as e:
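Both branches double the request timeout from 300 s to 600 s. For reference, a single `timeout` value in `requests` bounds the connection and each read between bytes, not the total transfer, and `stream=True` defers downloading the body until it is consumed. A small sketch against a placeholder URL:

```python
import requests

# timeout=600 applies to connecting and to each read, not to total duration
resp = requests.post("https://example.com/task", json={"task": "demo"},
                     timeout=600, stream=True)
for chunk in resp.iter_content(chunk_size=8192):  # with stream=True the body arrives lazily
    pass
```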
 
@@ -172,6 +175,35 @@ def encode_audio_array(audio_array):
 
     return base64_string
 
+def validate_uploaded_audio(uploaded_file):
+    """
+    Validate uploaded audio file format and duration
+    Returns: (is_valid, error_message, audio_data, sample_rate)
+    """
+    allowed_extensions = ['.wav', '.mp3', '.flac']
+    file_extension = os.path.splitext(uploaded_file.name)[1].lower()
+
+    if file_extension not in allowed_extensions:
+        return False, f"Unsupported file format. Please upload {', '.join(allowed_extensions)} files only.", None, None
+
+    try:
+        audio_bytes = uploaded_file.read()
+
+        with tempfile.NamedTemporaryFile(delete=True, suffix=file_extension) as tmp_file:
+            tmp_file.write(audio_bytes)
+            temp_path = tmp_file.name
+
+            audio_data, sample_rate = librosa.load(temp_path, sr=None)
+            duration = len(audio_data) / sample_rate
+
+            if duration > 30:
+                return False, f"Audio duration ({duration:.1f}s) exceeds the 30-second limit. Please upload shorter audio.", None, None
+
+            return True, None, audio_data, sample_rate
+
+    except Exception as e:
+        return False, f"Error processing audio file: {str(e)}", None, None
+
 def call_function(model_name):
     if st.session_state.current_audio_type == "recorded":
         y,_ = librosa.load(st.session_state.audio_path,sr=22050,mono=True)
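`validate_uploaded_audio` only touches `uploaded_file.name` and `uploaded_file.read()`, so it can be exercised without Streamlit. A sketch with a stand-in upload object, assuming the function is importable and `soundfile` is installed (librosa depends on it):

```python
import io
import numpy as np
import soundfile as sf

class FakeUpload(io.BytesIO):
    """Stand-in for Streamlit's UploadedFile: BytesIO plus a .name attribute."""
    name = "sample.wav"

buf = io.BytesIO()
sf.write(buf, np.zeros(22050, dtype="float32"), 22050, format="WAV")  # 1 s of silence
ok, err, data, sr = validate_uploaded_audio(FakeUpload(buf.getvalue()))
print(ok, err, sr)  # expected: True None 22050 (loaded with sr=None, so the rate is preserved)
```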
 
@@ -183,6 +215,20 @@ def call_function(model_name):
             "model_name":model_name,
             "audio_b64":True
         }}
+    elif st.session_state.current_audio_type == "uploaded":
+        # For uploaded files, use the processed audio data
+        array = st.session_state.audio['data']
+        sr = st.session_state.audio['sample_rate']
+        if sr != 22050:
+            array = librosa.resample(y=array, orig_sr=sr, target_sr=22050)
+        encoded_array = encode_audio_array(array)
+        payload = {
+            "task":"transcribe_with_fastapi",
+            "payload":{
+                "file_path":encoded_array,
+                "model_name":model_name,
+                "audio_b64":True
+            }}
     else:
         sr = st.session_state.audio['sample_rate']
         array = st.session_state.audio['data']
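The uploaded branch resamples to the 22,050 Hz the backend expects, matching the `sr=22050` used for recorded audio. `librosa.resample` in isolation:

```python
import numpy as np
import librosa

sr_in = 44100
t = np.arange(sr_in) / sr_in
y = np.sin(2 * np.pi * 440 * t).astype("float32")  # one second of a 440 Hz tone
y_22k = librosa.resample(y=y, orig_sr=sr_in, target_sr=22050)
print(len(y), "->", len(y_22k))  # 44100 -> 22050
```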
 
@@ -202,7 +248,20 @@ def call_function(model_name):
 
 def transcribe_audio():
     models_list = ["Ori Apex", "Ori Apex XT", "deepgram", "Ori Swift", "Ori Prime","azure"]
-    model1_name, model2_name = random.sample(models_list, 2)
+
+    if st.session_state.model_1_selection == "Random":
+        model1_name = random.choice(models_list)
+    else:
+        model1_name = st.session_state.model_1_selection
+
+    if st.session_state.model_2_selection == "Random":
+        if st.session_state.model_1_selection == "Random":
+            available_models = [m for m in models_list if m != model1_name]
+            model2_name = random.choice(available_models)
+        else:
+            model2_name = random.choice(models_list)
+    else:
+        model2_name = st.session_state.model_2_selection
 
     st.session_state.option_1_model_name = model1_name
     st.session_state.option_2_model_name = model2_name
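One consequence of this logic worth noting: two distinct models are only guaranteed when both dropdowns are on Random; if Model 1 is fixed and Model 2 is left on Random, the draw can repeat Model 1. A pure-function sketch of the same selection rules (a hypothetical helper, not part of the commit):

```python
import random

def pick_models(sel1, sel2, models):
    """Mirror of the rules above: 'Random' draws from the list, anything else is fixed."""
    m1 = random.choice(models) if sel1 == "Random" else sel1
    if sel2 == "Random":
        # deduplicate only when both selections were Random
        pool = [m for m in models if m != m1] if sel1 == "Random" else models
        m2 = random.choice(pool)
    else:
        m2 = sel2
    return m1, m2

print(pick_models("Random", "Random", ["a", "b", "c"]))  # always two distinct models
```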
 
@@ -345,7 +404,7 @@ writer_thread = threading.Thread(target=result_writer_thread)
 writer_thread.start()
 
 def main():
-
+    st.set_page_config(layout="wide",initial_sidebar_state="collapsed")
     st.title("⚔️ Ori Speech-To-Text Arena ⚔️")
 
     if "has_audio" not in st.session_state:
 
@@ -374,7 +433,12 @@ def main():
         st.session_state.recording = True
     if "disable_voting" not in st.session_state:
         st.session_state.disable_voting = True
-    col1, col2 = st.columns([1, 1])
+    if "model_1_selection" not in st.session_state:
+        st.session_state.model_1_selection = "Random"
+    if "model_2_selection" not in st.session_state:
+        st.session_state.model_2_selection = "Random"
+
+    col1, col2, col3 = st.columns([1, 1, 1])
 
     with col1:
         st.markdown("### Record Audio")
 
@@ -406,9 +470,69 @@ def main():
             st.button("🎲 Select Random Audio",on_click=on_random_click,key="random_btn")
             st.session_state.recording = False
 
+    with col3:
+        st.markdown("### Upload Audio File")
+        with st.container():
+            uploaded_file = st.file_uploader(
+                "Choose an audio file",
+                type=['wav', 'mp3', 'flac'],
+                key="audio_uploader",
+                help="Upload .wav, .mp3, or .flac files (max 30 seconds)"
+            )
+
+            if uploaded_file is not None:
+                if uploaded_file != st.session_state.get('last_uploaded_file'):
+                    st.session_state.last_uploaded_file = uploaded_file
+
+                    with st.spinner("Processing uploaded audio..."):
+                        is_valid, error_msg, audio_data, sample_rate = validate_uploaded_audio(uploaded_file)
+
+                    if is_valid:
+                        reset_state()
+
+                        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp_file:
+                            tmp_file.write(uploaded_file.getvalue())
+                            temp_path = tmp_file.name
+
+                        st.session_state.audio = {
+                            "data": audio_data,
+                            "sample_rate": sample_rate,
+                            "format": "audio/wav"
+                        }
+                        st.session_state.current_audio_type = "uploaded"
+                        st.session_state.has_audio = True
+                        st.session_state.audio_path = temp_path
+                        st.session_state.option_selected = None
+                        st.session_state.recording = False
+
+                        duration = len(audio_data) / sample_rate
+                        st.success(f"✅ Audio uploaded successfully! Duration: {duration:.1f}s")
+                    else:
+                        st.error(f"❌ {error_msg}")
+
     if st.session_state.has_audio:
         st.audio(**st.session_state.audio)
 
+    st.markdown("### Model Selection")
+    col_model1, col_model2 = st.columns(2)
+
+    models_list = ["Random", "Ori Apex", "Ori Apex XT", "deepgram", "Ori Swift", "Ori Prime", "azure"]
+
+    with col_model1:
+        st.selectbox(
+            "Model 1:",
+            options=models_list,
+            index=0,
+            key="model_1_selection"
+        )
+
+    with col_model2:
+        st.selectbox(
+            "Model 2:",
+            options=models_list,
+            index=0,
+            key="model_2_selection"
+        )
 
     with st.container():
         st.button("📝 Transcribe Audio",on_click=on_click_transcribe,use_container_width=True,key="transcribe_btn",disabled=st.session_state.recording)
 
@@ -449,7 +573,8 @@ def main():
 
     INSTR = """
 ## Instructions:
-* Record audio to recognise speech (or press 🎲 for random Audio).
+* Record audio to recognise speech, upload an audio file, or press 🎲 for random Audio.
+* Optionally select specific models using the Model 1 and Model 2 dropdowns (default is Random).
 * Click on transcribe audio button to commence the transcription process.
 * Read the two options one after the other while listening to the audio.
 * Vote on which transcript you prefer.
 
@@ -458,6 +583,7 @@ def main():
 * Currently Hindi and English are supported, and
 the results for Hindi will be in Hinglish (Hindi in Latin script)
 * It may take up to 30 seconds for speech recognition in some cases.
+* Uploaded audio files must be .wav, .mp3, or .flac format and under 30 seconds duration.
 """.strip()
 
     st.markdown(INSTR)
enums.py CHANGED
@@ -1,4 +1,7 @@
 import os
+from dotenv import load_dotenv
+
+load_dotenv()
 
 SAVE_PATH = f"s3://{os.getenv('AWS_BUCKET_NAME')}/{os.getenv('RESULTS_KEY')}"
 ELO_JSON_PATH = f"s3://{os.getenv('AWS_BUCKET_NAME')}/{os.getenv('ELO_JSON_PATH')}"
logger.py CHANGED
@@ -1,5 +1,8 @@
 import logging
 import os
+from dotenv import load_dotenv
+
+load_dotenv()
 
 loglevel = os.getenv("LOGLEVEL", "INFO")
 
requirements.txt CHANGED
@@ -9,4 +9,5 @@ streamlit==1.40.2
 fsspec==2024.10.0
 boto3
 s3fs
-torchaudio
+torchaudio
+python-dotenv
utils.py CHANGED
@@ -2,6 +2,9 @@ import fsspec
 import boto3
 import os
 import re
+from dotenv import load_dotenv
+
+load_dotenv()
 
 fs = fsspec.filesystem(
     's3',