import gradio as gr
import numpy as np
import torch
import torchaudio
from sherpa_onnx import OnlineRecognizer

# Initialize one streaming (online) transducer recognizer per language.
# The token and model files are assumed to sit next to this script.
recognizer_en = OnlineRecognizer.from_transducer(
    tokens="en_tokens.txt",
    encoder="en_encoder.onnx",
    decoder="en_decoder.onnx",
    joiner="en_joiner.onnx",
    num_threads=1,
    decoding_method="modified_beam_search",
    debug=False,
)

recognizer_fr = OnlineRecognizer.from_transducer(
    tokens="fr_tokens.txt",
    encoder="fr_encoder.onnx",
    decoder="fr_decoder.onnx",
    joiner="fr_joiner.onnx",
    num_threads=1,
    decoding_method="modified_beam_search",
    debug=False,
)

recognizer_de = OnlineRecognizer.from_transducer(
    tokens="de_tokens.txt",
    encoder="de_encoder.onnx",
    decoder="de_decoder.onnx",
    joiner="de_joiner.onnx",
    num_threads=1,
    decoding_method="modified_beam_search",
    debug=False,
)


def get_recognizer(language):
    """Map the UI language choice to its recognizer (defaults to English)."""
    match language:
        case "French":
            return recognizer_fr
        case "German":
            return recognizer_de
        case _:
            return recognizer_en


def transcribe_audio_online_streaming(file, language):
    """Generator that yields partial transcripts while decoding a file."""
    if file is None:
        yield "Please upload an audio file."
        return

    try:
        recognizer = get_recognizer(language)

        # With gr.File(type="filepath") the handler receives a plain path
        # string, so pass it to torchaudio.load() directly.
        waveform, sample_rate = torchaudio.load(file)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=16000
            )
            waveform = resampler(waveform)
            sample_rate = 16000

        # Keep only the first channel as a 1-D float array.
        waveform_np = waveform.numpy()[0]

        # Pad 0.5 seconds of silence at both ends so the model sees clean
        # boundaries around the speech.
        pad_samples = int(0.5 * sample_rate)
        silence = np.zeros(pad_samples, dtype=np.float32)
        waveform_np = np.concatenate([silence, waveform_np, silence])

        total_samples = waveform_np.shape[0]
        s = recognizer.create_stream()
        chunk_size = 4000  # 0.25 s of audio at 16 kHz
        offset = 0
        while offset < total_samples:
            chunk = waveform_np[offset : offset + chunk_size]
            s.accept_waveform(sample_rate, chunk.tolist())
            while recognizer.is_ready(s):
                recognizer.decode_streams([s])
            yield recognizer.get_result(s)
            offset += chunk_size

        # Flush: feed tail padding, mark the input finished, and decode
        # whatever frames remain.
        tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
        s.accept_waveform(sample_rate, tail_paddings.tolist())
        s.input_finished()
        while recognizer.is_ready(s):
            recognizer.decode_streams([s])

        current_text = recognizer.get_result(s)
        if isinstance(current_text, (list, np.ndarray)):
            current_text = " ".join(map(str, current_text))
        elif isinstance(current_text, bytes):
            current_text = current_text.decode("utf-8", errors="ignore")
        yield current_text
    except Exception as e:
        yield f"Error: {e}"
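# A quick way to exercise the generator above without the UI (a sketch;
# "sample_en.wav" is a hypothetical 16 kHz test file you supply). Since
# gr.File(type="filepath") passes a plain path string, the same call
# shape works from the command line:
#
#   for partial in transcribe_audio_online_streaming("sample_en.wav", "English"):
#       print(partial)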
def transcribe_microphone_stream(audio_chunk, stream_state, language):
    """Real-time microphone streaming transcription."""
    try:
        recognizer = get_recognizer(language)

        if audio_chunk is None:
            # End of stream: flush the remaining audio, if any.
            if stream_state is not None:
                tail_paddings = np.zeros(int(0.66 * 16000), dtype=np.float32)
                stream_state.accept_waveform(16000, tail_paddings.tolist())
                stream_state.input_finished()
                while recognizer.is_ready(stream_state):
                    recognizer.decode_streams([stream_state])
                final_text = recognizer.get_result(stream_state)
                return final_text, None
            return "", None

        sample_rate, waveform_np = audio_chunk

        # Downmix multi-channel input to mono.
        if waveform_np.ndim > 1:
            waveform_np = waveform_np.mean(axis=1)

        # Gradio delivers microphone audio as 16-bit integer PCM, while
        # sherpa-onnx expects float32 samples in [-1, 1].
        if waveform_np.dtype == np.int16:
            waveform_np = waveform_np.astype(np.float32) / 32768.0
        else:
            waveform_np = waveform_np.astype(np.float32)

        # Resample to the 16 kHz rate the models expect.
        if sample_rate != 16000:
            waveform = torch.from_numpy(waveform_np).float().unsqueeze(0)
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=16000
            )
            waveform = resampler(waveform)
            waveform_np = waveform.squeeze(0).numpy()
            sample_rate = 16000

        # First chunk: create the decoding stream.
        if stream_state is None:
            stream_state = recognizer.create_stream()

        # Feed the chunk, then decode every frame that is ready so far.
        stream_state.accept_waveform(sample_rate, waveform_np.tolist())
        while recognizer.is_ready(stream_state):
            recognizer.decode_streams([stream_state])

        current_text = recognizer.get_result(stream_state)
        if isinstance(current_text, (list, np.ndarray)):
            current_text = " ".join(map(str, current_text))
        elif isinstance(current_text, bytes):
            current_text = current_text.decode("utf-8", errors="ignore")

        return current_text, stream_state
    except Exception as e:
        print(f"Stream error: {e}")
        return str(e), stream_state


def create_app():
    with gr.Blocks() as app:
        gr.Markdown("# Real-time Speech Recognition")
        language_choice = gr.Radio(
            choices=["English", "French", "German"],
            label="Select Language",
            value="English",
        )
        with gr.Tabs():
            with gr.Tab("File Transcription"):
                gr.Markdown("Upload an audio file for streaming transcription")
                file_input = gr.File(label="Audio File", type="filepath")
                file_output = gr.Textbox(label="Transcription")
                transcribe_btn = gr.Button("Transcribe")
                # Clear the textbox first, then stream partial results
                # from the generator into it.
                transcribe_btn.click(lambda: "", outputs=file_output).then(
                    transcribe_audio_online_streaming,
                    inputs=[file_input, language_choice],
                    outputs=file_output,
                )
            with gr.Tab("Live Microphone"):
                gr.Markdown("Speak into your microphone for real-time transcription")
                mic = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    type="numpy",
                    label="Live Input",
                    show_download_button=False,
                )
                live_output = gr.Textbox(label="Live Transcription")
                state = gr.State()
                mic.stream(
                    transcribe_microphone_stream,
                    inputs=[mic, state, language_choice],
                    outputs=[live_output, state],
                    show_progress="hidden",
                )
    return app


if __name__ == "__main__":
    app = create_app()
    app.launch()
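# By default launch() serves on localhost only. To reach the app from
# another machine you can pass standard Gradio options (a sketch; the
# port value is a placeholder):
#
#   app.launch(server_name="0.0.0.0", server_port=7860)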