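"""Gradio app that transcribes uploaded audio with OpenAI Whisper.

Progress is surfaced by hooking the tqdm progress bar that
whisper.transcribe uses internally and forwarding its updates to
thread-local listeners.
"""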
import gc
import sys
import threading
from random import random  # only used by the commented-out writer path below
from typing import Union

import gradio as gr
import torch
import tqdm
import whisper
from whisper.utils import get_writer  # only used by the commented-out writer path below
class ProgressListenerHandle:
    """Context manager that registers a progress listener for the current thread."""

    def __init__(self, listener):
        self.listener = listener

    def __enter__(self):
        register_thread_local_progress_listener(self.listener)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        unregister_thread_local_progress_listener(self.listener)
        if exc_type is None:
            self.listener.on_finished()
class _CustomProgressBar(tqdm.tqdm):
    """tqdm subclass that forwards progress updates to registered listeners."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._current = self.n  # set the initial value

    def update(self, n):
        super().update(n)
        # Because the progress bar might be disabled, track progress manually
        self._current += n

        # Inform the listeners registered on this thread
        for listener in _get_thread_local_listeners():
            listener.on_progress(self._current, self.total)
_thread_local = threading.local()


def _get_thread_local_listeners():
    if not hasattr(_thread_local, 'listeners'):
        _thread_local.listeners = []
    return _thread_local.listeners
_hooked = False


def init_progress_hook():
    global _hooked
    if _hooked:
        return

    # Inject our progress bar into whisper.transcribe's tqdm so we can observe progress
    import whisper.transcribe
    transcribe_module = sys.modules['whisper.transcribe']
    transcribe_module.tqdm.tqdm = _CustomProgressBar
    _hooked = True
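# Note: the hook above patches the tqdm class used by whisper.transcribe for
# the whole process, while the listener registry itself is thread-local, so a
# progress bar created on a given thread only notifies that thread's listeners.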
def register_thread_local_progress_listener(progress_listener):
    # This is a workaround for the fact that the progress bar is not exposed
    # in Whisper's API
    init_progress_hook()
    _get_thread_local_listeners().append(progress_listener)


def unregister_thread_local_progress_listener(progress_listener):
    listeners = _get_thread_local_listeners()
    if progress_listener in listeners:
        listeners.remove(progress_listener)


def create_progress_listener_handle(progress_listener):
    return ProgressListenerHandle(progress_listener)
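# Usage sketch (illustrative only, not part of the app flow): any object with
# on_progress(current, total) and on_finished() methods can act as a listener:
#
#   class MyListener:
#       def on_progress(self, current, total):
#           print(f"{current}/{total}")
#       def on_finished(self):
#           print("done")
#
#   with create_progress_listener_handle(MyListener()):
#       whisper.load_model("base").transcribe("audio.wav", verbose=False)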
class PrintingProgressListener:
    """Listener that forwards Whisper's progress to a Gradio progress bar."""

    def __init__(self, progress):
        self.progress = progress

    def on_progress(self, current: Union[int, float], total: Union[int, float]):
        self.progress(current / total, desc="Transcribing")
        print(f"Progress: {current}/{total}")

    def on_finished(self):
        self.progress(1, desc="Transcribed!")
        print("Finished")
models = ['base', 'small', 'medium', 'large']
output_formats = ["txt", "vtt", "srt", "tsv", "json"]

# Cache the most recently loaded model so repeated requests with the same
# model size don't reload it.
locModeltype = ""
locModel = None
def transcribe_audio(model, audio, progress=gr.Progress()):
    global locModel
    global locModeltype
    try:
        progress(0, desc="Starting...")
        # If a different model size was requested, unload the previous model
        # and load the new one
        if locModeltype != model:
            locModeltype = model
            locModel = None  # drop the reference before collecting
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            progress(0, desc="Loading model...")
            locModel = whisper.load_model(model)
        progress(0, desc="Transcribing")
        with create_progress_listener_handle(PrintingProgressListener(progress)):
            result = locModel.transcribe(audio, verbose=False)
        # Alternate output path (currently disabled): write the result in the
        # selected format and return the file contents instead of plain text.
        # path = f"/tmp/{oformat}{random()}"
        # writer = get_writer(oformat, path)
        # writer(result, path)
        # with open(path, 'r') as f:
        #     rz = f.read()
        #     if rz is None:
        #         rz = result['text']
        return f"language: {result['language']}\n\n{result['text']}"
    except Exception as w:
        raise gr.Error(f"Error: {str(w)}")
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Dropdown(models, value=models[2], label="Model size",
                    info="Model size determines the accuracy of the output text at the cost of speed"),
        # gr.Dropdown(output_formats, value=output_formats[0], label="Output format", info="Format of the output text"),
        # gr.Checkbox(value=False, label="Timestamps", info="Add timestamps to know when each segment was said"),
        gr.Audio(label="Audio to transcribe", source='upload', type="filepath"),
    ],
    outputs="text",
    allow_flagging="never")

demo.queue().launch()
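# queue() enables Gradio's event queue, which is what streams gr.Progress
# updates to the client; without it the progress bar would not update.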