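"""Gradio demo for real-time speech recognition with sherpa-onnx.

Streams transcription of uploaded audio files and of live microphone input,
using online transducer models for English, French, and German.
"""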
import numpy as np
import gradio as gr
import torchaudio
import torch
from sherpa_onnx import OnlineRecognizer
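# Assumed dependencies (not pinned in this file): sherpa-onnx, gradio, torch, torchaudio.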
# Initialize one streaming recognizer per supported language.
recognizer_en = OnlineRecognizer.from_transducer(
    tokens="en_tokens.txt",
    encoder="en_encoder.onnx",
    decoder="en_decoder.onnx",
    joiner="en_joiner.onnx",
    num_threads=1,
    decoding_method="modified_beam_search",
    debug=False,
)

recognizer_fr = OnlineRecognizer.from_transducer(
    tokens="fr_tokens.txt",
    encoder="fr_encoder.onnx",
    decoder="fr_decoder.onnx",
    joiner="fr_joiner.onnx",
    num_threads=1,
    decoding_method="modified_beam_search",
    debug=False,
)

recognizer_de = OnlineRecognizer.from_transducer(
    tokens="de_tokens.txt",
    encoder="de_encoder.onnx",
    decoder="de_decoder.onnx",
    joiner="de_joiner.onnx",
    num_threads=1,
    decoding_method="modified_beam_search",
    debug=False,
)
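
# The *_tokens.txt and *.onnx model files above are assumed to sit next to
# app.py; loading fails at startup if any of them is missing.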

def transcribe_audio_online_streaming(file, language):
    """Generator that yields partial transcripts while decoding an uploaded file."""
    if file is None:
        yield "Please upload an audio file."
        return
    try:
        match language:
            case "English":
                recognizer = recognizer_en
            case "French":
                recognizer = recognizer_fr
            case "German":
                recognizer = recognizer_de
            case _:  # guard against an unexpected value from the UI
                recognizer = recognizer_en

        # gr.File with type="filepath" passes the path as a plain string.
        waveform, sample_rate = torchaudio.load(file)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
            sample_rate = 16000
        waveform_np = waveform.numpy()[0]  # first channel only

        # Add 0.5 seconds of silence padding at the beginning and end so the
        # model has context around the first and last words.
        pad_duration = 0.5  # seconds
        pad_samples = int(pad_duration * sample_rate)
        padding = np.zeros(pad_samples, dtype=np.float32)
        waveform_np = np.concatenate([padding, waveform_np, padding])

        total_samples = waveform_np.shape[0]
        s = recognizer.create_stream()
        chunk_size = 4000  # 0.25-second chunks at 16 kHz
        offset = 0
        while offset < total_samples:
            chunk = waveform_np[offset:offset + chunk_size]
            s.accept_waveform(sample_rate, chunk.tolist())
            while recognizer.is_ready(s):
                recognizer.decode_streams([s])
            yield recognizer.get_result(s)
            offset += chunk_size

        # Final processing: feed trailing silence, mark the stream finished,
        # and drain whatever frames the model still has buffered.
        tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
        s.accept_waveform(sample_rate, tail_paddings.tolist())
        s.input_finished()
        while recognizer.is_ready(s):
            recognizer.decode_streams([s])
        current_text = recognizer.get_result(s)
        if isinstance(current_text, (list, np.ndarray)):
            current_text = " ".join(map(str, current_text))
        elif isinstance(current_text, bytes):
            current_text = current_text.decode("utf-8", errors="ignore")
        yield current_text
    except Exception as e:
        yield f"Error: {e}"
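
# Hypothetical direct use outside Gradio (the path and language here are
# examples, not files shipped with this Space):
#   for partial in transcribe_audio_online_streaming("sample.wav", "English"):
#       print(partial)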

def transcribe_microphone_stream(audio_chunk, stream_state, language):
    """Real-time microphone streaming transcription."""
    try:
        match language:
            case "English":
                recognizer = recognizer_en
            case "French":
                recognizer = recognizer_fr
            case "German":
                recognizer = recognizer_de
            case _:  # guard against an unexpected value from the UI
                recognizer = recognizer_en

        if audio_chunk is None:  # end of stream
            if stream_state is not None:
                # Flush remaining audio with trailing silence.
                tail_paddings = np.zeros(int(0.66 * 16000), dtype=np.float32)
                stream_state.accept_waveform(16000, tail_paddings.tolist())
                stream_state.input_finished()
                while recognizer.is_ready(stream_state):
                    recognizer.decode_streams([stream_state])
                final_text = recognizer.get_result(stream_state)
                return final_text, None
            return "", None

        sample_rate, waveform_np = audio_chunk
        # Gradio delivers microphone chunks as int16 PCM; sherpa-onnx expects
        # floats in [-1, 1], so rescale before decoding.
        if np.issubdtype(waveform_np.dtype, np.integer):
            waveform_np = waveform_np.astype(np.float32) / 32768.0
        else:
            waveform_np = waveform_np.astype(np.float32)
        if waveform_np.ndim > 1:
            waveform_np = waveform_np.mean(axis=1)  # downmix to mono

        # Resample if needed
        if sample_rate != 16000:
            waveform = torch.from_numpy(waveform_np).float().unsqueeze(0)
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform_np = resampler(waveform).squeeze(0).numpy()
            sample_rate = 16000

        # Initialize the stream on the first chunk; it persists in gr.State
        # between callbacks.
        if stream_state is None:
            stream_state = recognizer.create_stream()

        # Feed the chunk and decode every frame that is ready so far.
        stream_state.accept_waveform(sample_rate, waveform_np.tolist())
        while recognizer.is_ready(stream_state):
            recognizer.decode_streams([stream_state])

        current_text = recognizer.get_result(stream_state)
        if isinstance(current_text, (list, np.ndarray)):
            current_text = " ".join(map(str, current_text))
        elif isinstance(current_text, bytes):
            current_text = current_text.decode("utf-8", errors="ignore")
        return current_text, stream_state
    except Exception as e:
        print(f"Stream error: {e}")
        return str(e), stream_state
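
# Hypothetical manual drive of the streaming callback, assuming 16 kHz
# int16 numpy chunks (e.g. 4000 samples each):
#   state = None
#   for chunk in int16_chunks:
#       text, state = transcribe_microphone_stream((16000, chunk), state, "English")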

def create_app():
    with gr.Blocks() as app:
        gr.Markdown("# Real-time Speech Recognition")
        language_choice = gr.Radio(
            choices=["English", "French", "German"],
            label="Select Language",
            value="English",
        )
        with gr.Tabs():
            with gr.Tab("File Transcription"):
                gr.Markdown("Upload an audio file for streaming transcription")
                file_input = gr.File(label="Audio File", type="filepath")
                file_output = gr.Textbox(label="Transcription")
                transcribe_btn = gr.Button("Transcribe")
                # Clear the textbox first, then stream partial results into it.
                transcribe_btn.click(lambda: "", outputs=file_output).then(
                    transcribe_audio_online_streaming,
                    inputs=[file_input, language_choice],
                    outputs=file_output,
                )
            with gr.Tab("Live Microphone"):
                gr.Markdown("Speak into your microphone for real-time transcription")
                mic = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    type="numpy",
                    label="Live Input",
                    show_download_button=False,
                )
                live_output = gr.Textbox(label="Live Transcription")
                state = gr.State()  # holds the sherpa-onnx stream between chunks
                mic.stream(
                    transcribe_microphone_stream,
                    inputs=[mic, state, language_choice],
                    outputs=[live_output, state],
                    show_progress="hidden",
                )
    return app


if __name__ == "__main__":
    app = create_app()
    app.launch()