# Hugging Face Space: Gradio front-end for OpenAI Whisper transcription.
import gradio as gr | |
import whisper | |
import sys | |
import threading | |
from typing import List, Union | |
import tqdm | |
class ProgressListenerHandle:
    """Context manager that registers a progress listener for the current
    thread on entry and unregisters it on exit.

    On a clean exit (no exception) the listener's ``on_finished`` hook is
    invoked.
    """

    def __init__(self, listener):
        self.listener = listener

    def __enter__(self):
        register_thread_local_progress_listener(self.listener)
        # Return the handle so `with ... as handle:` binds a usable object
        # (the original returned None, making the `as` target useless).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        unregister_thread_local_progress_listener(self.listener)
        # Only report completion when the body exited without an exception.
        if exc_type is None:
            self.listener.on_finished()
class _CustomProgressBar(tqdm.tqdm):
    """tqdm subclass that mirrors progress to thread-local listeners.

    Injected into ``whisper.transcribe`` (see ``init_progress_hook``) so we
    can observe transcription progress even when the bar itself is disabled.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Track progress ourselves: a disabled bar does not maintain `n`.
        self._current = self.n  # Set the initial value

    def update(self, n=1):
        # Default n=1 matches tqdm.update()'s own signature, so callers that
        # invoke update() with no argument keep working.
        super().update(n)
        # Because the progress bar might be disabled, we need to manually update the progress
        self._current += n
        # Inform listeners registered on this thread.
        listeners = _get_thread_local_listeners()
        for listener in listeners:
            listener.on_progress(self._current, self.total)
_thread_local = threading.local() | |
def _get_thread_local_listeners(): | |
if not hasattr(_thread_local, 'listeners'): | |
_thread_local.listeners = [] | |
return _thread_local.listeners | |
# Process-wide flag: the tqdm patch below must be applied at most once.
_hooked = False


def init_progress_hook():
    """Monkey-patch the tqdm class used by ``whisper.transcribe``.

    Whisper drives a tqdm bar internally and exposes no progress callback;
    swapping in ``_CustomProgressBar`` lets registered listeners observe
    progress. Idempotent via the module-level ``_hooked`` flag.
    """
    global _hooked
    if not _hooked:
        import whisper.transcribe
        sys.modules['whisper.transcribe'].tqdm.tqdm = _CustomProgressBar
        _hooked = True
def register_thread_local_progress_listener(progress_listener):
    """Attach *progress_listener* to the current thread's listener list.

    Also installs the tqdm hook on first use — a workaround for Whisper not
    exposing progress in its API.
    """
    init_progress_hook()
    _get_thread_local_listeners().append(progress_listener)
def unregister_thread_local_progress_listener(progress_listener):
    """Detach *progress_listener* from this thread; a no-op if absent."""
    try:
        _get_thread_local_listeners().remove(progress_listener)
    except ValueError:
        # Listener was never registered on this thread — nothing to do.
        pass
def create_progress_listener_handle(progress_listener):
    """Wrap *progress_listener* in a registering context-manager handle."""
    handle = ProgressListenerHandle(progress_listener)
    return handle
class PrintingProgressListener:
    """Forwards Whisper progress to a Gradio progress callback and stdout.

    Parameters
    ----------
    progress : callable
        A ``gr.Progress``-style callable accepting ``(fraction, desc=...)``.
    """

    def __init__(self, progress):
        self.progress = progress

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        # Guard: tqdm reports total=None when the length is unknown, and a
        # zero total would divide by zero — skip the fractional update then.
        if total:
            self.progress(current / total, desc="Transcribing")
        print(f"Progress: {current}/{total}")

    def on_finished(self):
        self.progress(1, desc="Transcribed!")
        print("Finished")
import gc | |
import torch | |
from whisper.utils import get_writer | |
from random import random | |
# Whisper checkpoint sizes offered in the UI (ordered smallest -> largest).
models = ['base', 'small', 'medium', 'large']
# Formats accepted by whisper.utils.get_writer (UI dropdown currently disabled).
output_formats = ["txt", "vtt", "srt", "tsv", "json"]
# Cache of the currently loaded model so repeat requests skip reloading:
# locModeltype is the loaded checkpoint's name ("" = nothing loaded yet),
# locModel holds the loaded whisper model instance (None until first load).
locModeltype = ""
locModel = None
def transcribe_audio(model, audio, progress=gr.Progress()):
    """Transcribe *audio* with the requested Whisper *model*.

    Parameters
    ----------
    model : str
        Checkpoint name (one of ``models``); the model is reloaded only
        when this differs from the cached one.
    audio : str
        Filepath of the uploaded audio (Gradio ``type="filepath"``).
    progress : gr.Progress
        Gradio progress tracker, driven via the listener hook while
        transcribing.

    Returns
    -------
    str
        The detected language followed by the transcribed text.

    Raises
    ------
    gr.Error
        Wraps any failure so Gradio surfaces it to the user.
    """
    global locModel
    global locModeltype
    try:
        progress(0, desc="Starting...")
        # If a different model is requested, unload the previous one and load the new one.
        if locModeltype != model:
            # Rebind (not `del`) so the global stays defined even if the
            # load below fails; then release GPU memory before loading.
            locModel = None
            gc.collect()
            torch.cuda.empty_cache()
            progress(0, desc="Loading model...")
            locModel = whisper.load_model(model)
            # Record the checkpoint name only after a successful load, so a
            # failed load is retried on the next request instead of leaving
            # a stale cache entry pointing at a missing model.
            locModeltype = model
        progress(0, desc="Transcribing")
        with create_progress_listener_handle(PrintingProgressListener(progress)):
            result = locModel.transcribe(audio, verbose=False)
        return f"language: {result['language']}\n\n{result['text']}"
    except Exception as w:
        raise gr.Error(f"Error: {str(w)}")
# Build the web UI: model-size dropdown + audio upload in, plain text out.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Dropdown(models, value=models[2], label="Model size", info="Model size determines the accuracy of the output text at the cost of speed"),
        # gr.Dropdown(output_formats, value=output_formats[0], label="Output format", info="Format output text"),
        # gr.Checkbox(value=False, label="Timestamps", info="Add timestampts to know when what was said"),
        gr.Audio(label="Audio to transcribe",source='upload',type="filepath")
    ],
    allow_flagging="never",
    outputs="text")
# queue() is required for gr.Progress updates to stream; launch() starts the server.
demo.queue().launch()