import numpy as np
import gradio as gr
import torchaudio
import torch
from sherpa_onnx import OnlineRecognizer

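# NOTE: the en_/fr_/de_ files referenced below are assumed to be streaming
# (online) transducer models exported for sherpa-onnx -- a tokens file plus
# encoder/decoder/joiner ONNX files per language. Offline (non-streaming)
# exports will not load via OnlineRecognizer.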
recognizer_en = OnlineRecognizer.from_transducer(
    tokens="en_tokens.txt",
    encoder="en_encoder.onnx",
    decoder="en_decoder.onnx",
    joiner="en_joiner.onnx",
    num_threads=1,
    decoding_method="modified_beam_search",
    debug=False,
)

recognizer_fr = OnlineRecognizer.from_transducer(
    tokens="fr_tokens.txt",
    encoder="fr_encoder.onnx",
    decoder="fr_decoder.onnx",
    joiner="fr_joiner.onnx",
    num_threads=1,
    decoding_method="modified_beam_search",
    debug=False,
)

recognizer_de = OnlineRecognizer.from_transducer(
    tokens="de_tokens.txt",
    encoder="de_encoder.onnx",
    decoder="de_decoder.onnx",
    joiner="de_joiner.onnx",
    num_threads=1,
    decoding_method="modified_beam_search",
    debug=False,
)
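
# Decoding pattern used by both handlers below: feed samples into a stream
# with accept_waveform(), drain the decoder with decode_streams() while
# is_ready() reports buffered frames, then read the transcript so far with
# get_result().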

def transcribe_audio_online_streaming(file, language):
    """Generator that yields partial transcripts while decoding an uploaded file."""
    if file is None:
        yield "Please upload an audio file."
        return

    try:
        match language:
            case "English":
                recognizer = recognizer_en
            case "French":
                recognizer = recognizer_fr
            case "German":
                recognizer = recognizer_de
            case _:
                yield f"Unsupported language: {language}"
                return

        # gr.File(type="filepath") passes the upload as a plain path string.
        waveform, sample_rate = torchaudio.load(file)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform = resampler(waveform)
            sample_rate = 16000

        waveform_np = waveform.numpy()[0]  # keep the first channel only

        # Pad 0.5 s of silence on both ends so speech at the file boundaries
        # is not clipped by the model's context window.
        pad_samples = int(0.5 * sample_rate)
        silence = np.zeros(pad_samples, dtype=np.float32)
        waveform_np = np.concatenate([silence, waveform_np, silence])

        total_samples = waveform_np.shape[0]

        s = recognizer.create_stream()
        chunk_size = 4000  # 0.25 s per chunk at 16 kHz
        offset = 0

        while offset < total_samples:
            chunk = waveform_np[offset:offset + chunk_size]
            s.accept_waveform(sample_rate, chunk)

            # Decode whatever frames are ready, then surface the partial result.
            while recognizer.is_ready(s):
                recognizer.decode_streams([s])

            yield recognizer.get_result(s)
            offset += chunk_size

        # Feed tail padding and signal end of input so the frames still inside
        # the model's right context are flushed and decoded.
        tail_paddings = np.zeros(int(0.66 * sample_rate), dtype=np.float32)
        s.accept_waveform(sample_rate, tail_paddings)
        s.input_finished()

        while recognizer.is_ready(s):
            recognizer.decode_streams([s])

        current_text = recognizer.get_result(s)
        if isinstance(current_text, (list, np.ndarray)):
            current_text = " ".join(map(str, current_text))
        elif isinstance(current_text, bytes):
            current_text = current_text.decode("utf-8", errors="ignore")

        yield current_text

    except Exception as e:
        yield f"Error: {e}"
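
# Gradio streams microphone audio as (sample_rate, numpy_array) tuples when
# type="numpy"; the array is typically int16 PCM, hence the normalisation to
# float32 in [-1, 1] inside the handler below.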

def transcribe_microphone_stream(audio_chunk, stream_state, language):
    """Real-time microphone streaming transcription."""
    try:
        match language:
            case "English":
                recognizer = recognizer_en
            case "French":
                recognizer = recognizer_fr
            case "German":
                recognizer = recognizer_de
            case _:
                return f"Unsupported language: {language}", stream_state

        if audio_chunk is None:
            # Recording stopped: flush the stream, emit the final text,
            # and reset the session state.
            if stream_state is not None:
                tail_paddings = np.zeros(int(0.66 * 16000), dtype=np.float32)
                stream_state.accept_waveform(16000, tail_paddings)
                stream_state.input_finished()
                while recognizer.is_ready(stream_state):
                    recognizer.decode_streams([stream_state])
                final_text = recognizer.get_result(stream_state)
                return final_text, None
            return "", None

        sample_rate, waveform_np = audio_chunk
        if waveform_np.ndim > 1:
            waveform_np = waveform_np.mean(axis=1)  # downmix to mono

        # Microphone chunks arrive as int16 PCM; sherpa-onnx expects
        # float32 samples in [-1, 1].
        if waveform_np.dtype == np.int16:
            waveform_np = waveform_np.astype(np.float32) / 32768.0
        else:
            waveform_np = waveform_np.astype(np.float32)

        if sample_rate != 16000:
            waveform = torch.from_numpy(waveform_np).unsqueeze(0)
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            waveform_np = resampler(waveform).squeeze(0).numpy()
            sample_rate = 16000

        if stream_state is None:
            stream_state = recognizer.create_stream()

        stream_state.accept_waveform(sample_rate, waveform_np)

        while recognizer.is_ready(stream_state):
            recognizer.decode_streams([stream_state])

        current_text = recognizer.get_result(stream_state)

        if isinstance(current_text, (list, np.ndarray)):
            current_text = " ".join(map(str, current_text))
        elif isinstance(current_text, bytes):
            current_text = current_text.decode("utf-8", errors="ignore")

        return current_text, stream_state

    except Exception as e:
        print(f"Stream error: {e}")
        return str(e), stream_state
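
# The sherpa-onnx stream object is kept in gr.State so that each browser
# session accumulates its own decoding context across chunk callbacks.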

def create_app():
    with gr.Blocks() as app:
        gr.Markdown("# Real-time Speech Recognition")
        language_choice = gr.Radio(
            choices=["English", "French", "German"],
            label="Select Language",
            value="English",
        )

        with gr.Tabs():
            with gr.Tab("File Transcription"):
                gr.Markdown("Upload an audio file for streaming transcription")
                file_input = gr.File(label="Audio File", type="filepath")
                file_output = gr.Textbox(label="Transcription")
                transcribe_btn = gr.Button("Transcribe")
                # Clear the textbox first, then stream partial results from
                # the generator into it.
                transcribe_btn.click(lambda: "", outputs=file_output).then(
                    transcribe_audio_online_streaming,
                    inputs=[file_input, language_choice],
                    outputs=file_output,
                )

            with gr.Tab("Live Microphone"):
                gr.Markdown("Speak into your microphone for real-time transcription")
                mic = gr.Audio(
                    sources=["microphone"],
                    streaming=True,
                    type="numpy",
                    label="Live Input",
                    show_download_button=False,
                )
                live_output = gr.Textbox(label="Live Transcription")
                state = gr.State()

                mic.stream(
                    transcribe_microphone_stream,
                    inputs=[mic, state, language_choice],
                    outputs=[live_output, state],
                    show_progress="hidden",
                )

    return app


if __name__ == "__main__":
    app = create_app()
    app.launch()
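
# Usage sketch (assumed file layout): place the *_tokens.txt and *.onnx model
# files next to this script and run it directly; Gradio serves the UI on
# http://127.0.0.1:7860 by default.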