FAMILIA committed
Commit bd77f79 · 1 Parent(s): a2af5c8

Add application file

Files changed (17)
  1. .gitignore +5 -0
  2. README.md +58 -4
  3. app-local.py +3 -0
  4. app-network.py +3 -0
  5. app-shared.py +3 -0
  6. app.py +256 -0
  7. cli.py +110 -0
  8. dockerfile +20 -0
  9. docs/options.md +78 -0
  10. requirements.txt +6 -0
  11. src/__init__.py +0 -0
  12. src/download.py +72 -0
  13. src/segments.py +55 -0
  14. src/utils.py +115 -0
  15. src/vad.py +477 -0
  16. tests/segments_test.py +48 -0
  17. tests/vad_test.py +66 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ flagged/
+ *.py[cod]
+ *$py.class
README.md CHANGED
@@ -1,12 +1,66 @@
  ---
- title: Vozparatexto
  emoji: ⚡
- colorFrom: green
- colorTo: red
  sdk: gradio
- sdk_version: 3.10.1
  app_file: app.py
  pinned: false
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

  ---
+ title: Voz para Texto
  emoji: ⚡
+ colorFrom: pink
+ colorTo: purple
  sdk: gradio
+ sdk_version: 3.3.1
  app_file: app.py
  pinned: false
+ license: apache-2.0
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ # Running Locally
+
+ To run this program locally, first install Python 3.9+ and Git. Then install PyTorch 1.10.1+ and all the other dependencies:
+ ```
+ pip install -r requirements.txt
+ ```
+
+ Finally, run the full version (no audio length restrictions) of the app:
+ ```
+ python app-local.py
+ ```
+
+ You can also run the CLI interface, which is similar to Whisper's own CLI but also supports the following additional arguments:
+ ```
+ python cli.py \
+ [--vad {none,silero-vad,silero-vad-skip-gaps,silero-vad-expand-into-gaps,periodic-vad}] \
+ [--vad_merge_window VAD_MERGE_WINDOW] \
+ [--vad_max_merge_size VAD_MAX_MERGE_SIZE] \
+ [--vad_padding VAD_PADDING] \
+ [--vad_prompt_window VAD_PROMPT_WINDOW]
+ ```
+ You may also use URLs in addition to file paths as input.
+ ```
+ python cli.py --model large --vad silero-vad --language Japanese "https://www.youtube.com/watch?v=4cICErqqRSM"
+ ```
+
+ # Docker
+
+ To run it in Docker, first install Docker and, optionally, the NVIDIA Container Toolkit in order to use the GPU. Then
+ check out this repository and build an image:
+ ```
+ sudo docker build -t whisper-webui:1 .
+ ```
+
+ You can then start the WebUI with GPU support like so:
+ ```
+ sudo docker run -d --gpus=all -p 7860:7860 whisper-webui:1
+ ```
+
+ Leave out "--gpus=all" if you don't have access to a GPU with enough memory, and are fine with running it on the CPU only:
+ ```
+ sudo docker run -d -p 7860:7860 whisper-webui:1
+ ```
+
+ ## Caching
+
+ Note that the models themselves are currently not included in the Docker images, and will be downloaded on demand.
+ To avoid this, bind the directory /root/.cache/whisper to some directory on the host (for instance /home/administrator/.cache/whisper), where you can (optionally)
+ prepopulate the directory with the different Whisper models.
+ ```
+ sudo docker run -d --gpus=all -p 7860:7860 --mount type=bind,source=/home/administrator/.cache/whisper,target=/root/.cache/whisper whisper-webui:1
+ ```
app-local.py ADDED
@@ -0,0 +1,3 @@
+ # Run the app with no audio file restrictions
+ from app import create_ui
+ create_ui(-1)
app-network.py ADDED
@@ -0,0 +1,3 @@
+ # Run the app with no audio file restrictions, and make it available on the network
+ from app import create_ui
+ create_ui(-1, server_name="0.0.0.0")
app-shared.py ADDED
@@ -0,0 +1,3 @@
+ # Run the app with no audio file restrictions, and create a public Gradio share link
+ from app import create_ui
+ create_ui(-1, share=True)
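The three launcher scripts above differ only in the arguments they pass to create_ui (defined in app.py below). Other combinations follow the same pattern; for example, a hypothetical launcher that keeps the default 600-second limit while exposing the app on the local network:

```python
# Hypothetical variant: keep a 600 s audio limit, but bind to all interfaces
# so other machines on the network can reach the UI.
from app import create_ui

create_ui(600, server_name="0.0.0.0")
```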
app.py ADDED
@@ -0,0 +1,256 @@
1
+ from typing import Iterator
2
+
3
+ from io import StringIO
4
+ import os
5
+ import pathlib
6
+ import tempfile
7
+
8
+ # External programs
9
+ import whisper
10
+ import ffmpeg
11
+
12
+ # UI
13
+ import gradio as gr
14
+
15
+ from src.download import ExceededMaximumDuration, download_url
16
+ from src.utils import slugify, write_srt, write_vtt
17
+ from src.vad import NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
18
+
19
+ # Limitations (set to -1 to disable)
20
+ DEFAULT_INPUT_AUDIO_MAX_DURATION = 600 # seconds
21
+
22
+ # Whether or not to automatically delete all uploaded files, to save disk space
23
+ DELETE_UPLOADED_FILES = True
24
+
25
+ # Gradio seems to truncate files without keeping the extension, so we need to truncate the file prefix ourselves
26
+ MAX_FILE_PREFIX_LENGTH = 17
27
+
28
+ LANGUAGES = [
29
+ "English", "Chinese", "German", "Spanish", "Russian", "Korean",
30
+ "French", "Japanese", "Portuguese", "Turkish", "Polish", "Catalan",
31
+ "Dutch", "Arabic", "Swedish", "Italian", "Indonesian", "Hindi",
32
+ "Finnish", "Vietnamese", "Hebrew", "Ukrainian", "Greek", "Malay",
33
+ "Czech", "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
34
+ "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian", "Latin",
35
+ "Maori", "Malayalam", "Welsh", "Slovak", "Telugu", "Persian",
36
+ "Latvian", "Bengali", "Serbian", "Azerbaijani", "Slovenian",
37
+ "Kannada", "Estonian", "Macedonian", "Breton", "Basque", "Icelandic",
38
+ "Armenian", "Nepali", "Mongolian", "Bosnian", "Kazakh", "Albanian",
39
+ "Swahili", "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
40
+ "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan", "Georgian",
41
+ "Belarusian", "Tajik", "Sindhi", "Gujarati", "Amharic", "Yiddish",
42
+ "Lao", "Uzbek", "Faroese", "Haitian Creole", "Pashto", "Turkmen",
43
+ "Nynorsk", "Maltese", "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan",
44
+ "Tagalog", "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
45
+ "Hausa", "Bashkir", "Javanese", "Sundanese"
46
+ ]
47
+
48
+ class WhisperTranscriber:
49
+ def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
50
+ self.model_cache = dict()
51
+
52
+ self.vad_model = None
53
+ self.inputAudioMaxDuration = inputAudioMaxDuration
54
+ self.deleteUploadedFiles = deleteUploadedFiles
55
+
56
+ def transcribe_webui(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow):
57
+ try:
58
+ source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)
59
+
60
+ try:
61
+ selectedLanguage = languageName.lower() if len(languageName) > 0 else None
62
+ selectedModel = modelName if modelName is not None else "base"
63
+
64
+ model = self.model_cache.get(selectedModel, None)
65
+
66
+ if not model:
67
+ model = whisper.load_model(selectedModel)
68
+ self.model_cache[selectedModel] = model
69
+
70
+ # Execute whisper
71
+ result = self.transcribe_file(model, source, selectedLanguage, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
72
+
73
+ # Write result
74
+ downloadDirectory = tempfile.mkdtemp()
75
+
76
+ filePrefix = slugify(sourceName, allow_unicode=True)
77
+ download, text, vtt = self.write_result(result, filePrefix, downloadDirectory)
78
+
79
+ return download, text, vtt
80
+
81
+ finally:
82
+ # Cleanup source
83
+ if self.deleteUploadedFiles:
84
+ print("Deleting source file " + source)
85
+ os.remove(source)
86
+
87
+ except ExceededMaximumDuration as e:
88
+ return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"
89
+
90
+ def transcribe_file(self, model: whisper.Whisper, audio_path: str, language: str, task: str = None, vad: str = None,
91
+ vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1, **decodeOptions: dict):
92
+
93
+ initial_prompt = decodeOptions.pop('initial_prompt', None)
94
+
95
+ if ('task' in decodeOptions):
96
+ task = decodeOptions.pop('task')
97
+
98
+ # Callable for processing an audio file
99
+ whisperCallable = lambda audio, segment_index, prompt, detected_language : model.transcribe(audio, \
100
+ language=language if language else detected_language, task=task, \
101
+ initial_prompt=self._concat_prompt(initial_prompt, prompt) if segment_index == 0 else prompt, \
102
+ **decodeOptions)
103
+
104
+ # The results
105
+ if (vad == 'silero-vad'):
106
+ # Silero VAD where non-speech gaps are transcribed
107
+ process_gaps = self._create_silero_config(NonSpeechStrategy.CREATE_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
108
+ result = self.vad_model.transcribe(audio_path, whisperCallable, process_gaps)
109
+ elif (vad == 'silero-vad-skip-gaps'):
110
+ # Silero VAD where non-speech gaps are simply ignored
111
+ skip_gaps = self._create_silero_config(NonSpeechStrategy.SKIP, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
112
+ result = self.vad_model.transcribe(audio_path, whisperCallable, skip_gaps)
113
+ elif (vad == 'silero-vad-expand-into-gaps'):
114
+ # Use Silero VAD where speech-segments are expanded into non-speech gaps
115
+ expand_gaps = self._create_silero_config(NonSpeechStrategy.EXPAND_SEGMENT, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow)
116
+ result = self.vad_model.transcribe(audio_path, whisperCallable, expand_gaps)
117
+ elif (vad == 'periodic-vad'):
118
+ # Very simple VAD - mark every 5 minutes as speech. This makes it less likely that Whisper enters an infinite loop, but
119
+ # it may create a break in the middle of a sentence, causing some artifacts.
120
+ periodic_vad = VadPeriodicTranscription()
121
+ result = periodic_vad.transcribe(audio_path, whisperCallable, PeriodicTranscriptionConfig(periodic_duration=vadMaxMergeSize, max_prompt_window=vadPromptWindow))
122
+ else:
123
+ # Default VAD
124
+ result = whisperCallable(audio_path, 0, None, None)
125
+
126
+ return result
127
+
128
+ def _concat_prompt(self, prompt1, prompt2):
129
+ if (prompt1 is None):
130
+ return prompt2
131
+ elif (prompt2 is None):
132
+ return prompt1
133
+ else:
134
+ return prompt1 + " " + prompt2
135
+
136
+ def _create_silero_config(self, non_speech_strategy: NonSpeechStrategy, vadMergeWindow: float = 5, vadMaxMergeSize: float = 150, vadPadding: float = 1, vadPromptWindow: float = 1):
137
+ # Use Silero VAD
138
+ if (self.vad_model is None):
139
+ self.vad_model = VadSileroTranscription()
140
+
141
+ config = TranscriptionConfig(non_speech_strategy = non_speech_strategy,
142
+ max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
143
+ segment_padding_left=vadPadding, segment_padding_right=vadPadding,
144
+ max_prompt_window=vadPromptWindow)
145
+
146
+ return config
147
+
148
+ def write_result(self, result: dict, source_name: str, output_dir: str):
149
+ if not os.path.exists(output_dir):
150
+ os.makedirs(output_dir)
151
+
152
+ text = result["text"]
153
+ language = result["language"]
154
+ languageMaxLineWidth = self.__get_max_line_width(language)
155
+
156
+ print("Max line width " + str(languageMaxLineWidth))
157
+ vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
158
+ srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
159
+
160
+ output_files = []
161
+ output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
162
+ output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
163
+ output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
164
+
165
+ return output_files, text, vtt
166
+
167
+ def clear_cache(self):
168
+ self.model_cache = dict()
169
+ self.vad_model = None
170
+
171
+ def __get_source(self, urlData, uploadFile, microphoneData):
172
+ if urlData:
173
+ # Download from YouTube
174
+ source = download_url(urlData, self.inputAudioMaxDuration)[0]
175
+ else:
176
+ # File input
177
+ source = uploadFile if uploadFile is not None else microphoneData
178
+
179
+ if self.inputAudioMaxDuration > 0:
180
+ # Calculate audio length
181
+ audioDuration = ffmpeg.probe(source)["format"]["duration"]
182
+
183
+ if float(audioDuration) > self.inputAudioMaxDuration:
184
+ raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")
185
+
186
+ file_path = pathlib.Path(source)
187
+ sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix
188
+
189
+ return source, sourceName
190
+
191
+ def __get_max_line_width(self, language: str) -> int:
192
+ if (language and language.lower() in ["japanese", "ja", "chinese", "zh"]):
193
+ # Chinese characters and kana are wider, so limit line length to 40 characters
194
+ return 40
195
+ else:
196
+ # TODO: Add more languages
197
+ # 80 latin characters should fit on a 1080p/720p screen
198
+ return 80
199
+
200
+ def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
201
+ segmentStream = StringIO()
202
+
203
+ if format == 'vtt':
204
+ write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
205
+ elif format == 'srt':
206
+ write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
207
+ else:
208
+ raise Exception("Unknown format " + format)
209
+
210
+ segmentStream.seek(0)
211
+ return segmentStream.read()
212
+
213
+ def __create_file(self, text: str, directory: str, fileName: str) -> str:
214
+ # Write the text to a file
215
+ with open(os.path.join(directory, fileName), 'w+', encoding="utf-8") as file:
216
+ file.write(text)
217
+
218
+ return file.name
219
+
220
+
221
+ def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
222
+ ui = WhisperTranscriber(inputAudioMaxDuration)
223
+
224
+ ui_description = "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
225
+ ui_description += " audio and is also a multi-task model that can perform multilingual speech recognition "
226
+ ui_description += " as well as speech translation and language identification. "
227
+
228
+ ui_description += "\n\n\n\nFor longer audio files (>10 minutes) not in English, it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
229
+
230
+ if inputAudioMaxDuration > 0:
231
+ ui_description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"
232
+
233
+ ui_article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"
234
+
235
+ demo = gr.Interface(fn=ui.transcribe_webui, description=ui_description, article=ui_article, inputs=[
236
+ gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
237
+ gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
238
+ gr.Text(label="URL (YouTube, etc.)"),
239
+ gr.Audio(source="upload", type="filepath", label="Upload Audio"),
240
+ gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
241
+ gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
242
+ gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], label="VAD"),
243
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
244
+ gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=30),
245
+ gr.Number(label="VAD - Padding (s)", precision=None, value=1),
246
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=3)
247
+ ], outputs=[
248
+ gr.File(label="Download"),
249
+ gr.Text(label="Transcription"),
250
+ gr.Text(label="Segments")
251
+ ])
252
+
253
+ demo.launch(share=share, server_name=server_name)
254
+
255
+ if __name__ == '__main__':
256
+ create_ui(DEFAULT_INPUT_AUDIO_MAX_DURATION)
cli.py ADDED
@@ -0,0 +1,110 @@
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ from urllib.parse import urlparse
5
+ import warnings
6
+ import numpy as np
7
+
8
+ import whisper
9
+
10
+ import torch
11
+ from app import LANGUAGES, WhisperTranscriber
12
+ from src.download import download_url
13
+
14
+ from src.utils import optional_float, optional_int, str2bool
15
+
16
+
17
+ def cli():
18
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
19
+ parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
20
+ parser.add_argument("--model", default="small", choices=["tiny", "base", "small", "medium", "large"], help="name of the Whisper model to use")
21
+ parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
22
+ parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
23
+ parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
24
+ parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
25
+
26
+ parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
27
+ parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES), help="language spoken in the audio, specify None to perform language detection")
28
+
29
+ parser.add_argument("--vad", type=str, default="none", choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], help="The voice activity detection algorithm to use")
30
+ parser.add_argument("--vad_merge_window", type=optional_float, default=5, help="The window size (in seconds) to merge voice segments")
31
+ parser.add_argument("--vad_max_merge_size", type=optional_float, default=30, help="The maximum size (in seconds) of a voice segment")
32
+ parser.add_argument("--vad_padding", type=optional_float, default=1, help="The padding (in seconds) to add to each voice segment")
33
+ parser.add_argument("--vad_prompt_window", type=optional_float, default=3, help="The window size of the prompt to pass to Whisper")
34
+
35
+ parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
36
+ parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
37
+ parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
38
+ parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
39
+ parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple lengt normalization by default")
40
+
41
+ parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
42
+ parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
43
+ parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
44
+ parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
45
+
46
+ parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
47
+ parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
48
+ parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
49
+ parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
50
+
51
+ args = parser.parse_args().__dict__
52
+ model_name: str = args.pop("model")
53
+ model_dir: str = args.pop("model_dir")
54
+ output_dir: str = args.pop("output_dir")
55
+ device: str = args.pop("device")
56
+ os.makedirs(output_dir, exist_ok=True)
57
+
58
+ if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
59
+ warnings.warn(f"{model_name} is an English-only model but received '{args['language']}'; using English instead.")
60
+ args["language"] = "en"
61
+
62
+ temperature = args.pop("temperature")
63
+ temperature_increment_on_fallback = args.pop("temperature_increment_on_fallback")
64
+ if temperature_increment_on_fallback is not None:
65
+ temperature = tuple(np.arange(temperature, 1.0 + 1e-6, temperature_increment_on_fallback))
66
+ else:
67
+ temperature = [temperature]
68
+
69
+ vad = args.pop("vad")
70
+ vad_merge_window = args.pop("vad_merge_window")
71
+ vad_max_merge_size = args.pop("vad_max_merge_size")
72
+ vad_padding = args.pop("vad_padding")
73
+ vad_prompt_window = args.pop("vad_prompt_window")
74
+
75
+ model = whisper.load_model(model_name, device=device, download_root=model_dir)
76
+ transcriber = WhisperTranscriber(deleteUploadedFiles=False)
77
+
78
+ for audio_path in args.pop("audio"):
79
+ sources = []
80
+
81
+ # Detect URL and download the audio
82
+ if (uri_validator(audio_path)):
83
+ # Download from YouTube/URL directly
84
+ for source_path in download_url(audio_path, maxDuration=-1, destinationDirectory=output_dir, playlistItems=None):
85
+ source_name = os.path.basename(source_path)
86
+ sources.append({ "path": source_path, "name": source_name })
87
+ else:
88
+ sources.append({ "path": audio_path, "name": os.path.basename(audio_path) })
89
+
90
+ for source in sources:
91
+ source_path = source["path"]
92
+ source_name = source["name"]
93
+
94
+ result = transcriber.transcribe_file(model, source_path, temperature=temperature,
95
+ vad=vad, vadMergeWindow=vad_merge_window, vadMaxMergeSize=vad_max_merge_size,
96
+ vadPadding=vad_padding, vadPromptWindow=vad_prompt_window, **args)
97
+
98
+ transcriber.write_result(result, source_name, output_dir)
99
+
100
+ transcriber.clear_cache()
101
+
102
+ def uri_validator(x):
103
+ try:
104
+ result = urlparse(x)
105
+ return all([result.scheme, result.netloc])
106
+ except:
107
+ return False
108
+
109
+ if __name__ == '__main__':
110
+ cli()
dockerfile ADDED
@@ -0,0 +1,20 @@
+ FROM huggingface/transformers-pytorch-gpu
+ EXPOSE 7860
+
+ ADD . /opt/whisper-webui/
+
+ # Latest version of transformers-pytorch-gpu seems to lack tk.
+ # Further, pip install fails, so we must upgrade pip first.
+ RUN apt-get -y install python3-tk
+ RUN python3 -m pip install --upgrade pip &&\
+ python3 -m pip install -r /opt/whisper-webui/requirements.txt
+
+ # Note: Models will be downloaded on demand to the directory /root/.cache/whisper.
+ # You can also bind this directory in the container to somewhere on the host.
+
+ # To be able to see logs in real time
+ ENV PYTHONUNBUFFERED=1
+
+ WORKDIR /opt/whisper-webui/
+ ENTRYPOINT ["python3"]
+ CMD ["app-network.py"]
docs/options.md ADDED
@@ -0,0 +1,78 @@
+ # Options
+ To transcribe or translate an audio file, you can either paste a URL from a website (all [websites](https://github.com/yt-dlp/yt-dlp/blob/master/supportedsites.md)
+ supported by YT-DLP will work, including YouTube), upload an audio file (choose "All Files (*.*)"
+ in the file selector to select any file type, including video files), or use the microphone.
+
+ For longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option.
+
+ ## Model
+ Select the model that Whisper will use to transcribe the audio:
+
+ | Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
+ |--------|------------|--------------------|--------------------|---------------|----------------|
+ | tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x |
+ | base | 74 M | base.en | base | ~1 GB | ~16x |
+ | small | 244 M | small.en | small | ~2 GB | ~6x |
+ | medium | 769 M | medium.en | medium | ~5 GB | ~2x |
+ | large | 1550 M | N/A | large | ~10 GB | 1x |
+
+ ## Language
+
+ Select the language, or leave it empty for Whisper to automatically detect it.
+
+ Note that if the selected language and the language in the audio differ, Whisper may start to translate the audio to the selected
+ language. For instance, if the audio is in English but you select Japanese, the model may translate the audio to Japanese.
+
+ ## Inputs
+ The options "URL (YouTube, etc.)", "Upload Audio" and "Microphone Input" allow you to send an audio input to the model.
+
+ Note that the UI will only process the first valid input - i.e. if you both enter a URL and upload an audio file, it will only process
+ the URL.
+
+ ## Task
+ Select the task - either "transcribe" to transcribe the audio to text, or "translate" to translate it to English.
+
+ ## VAD
+ Using a VAD will improve the timing accuracy of each transcribed line, as well as prevent Whisper from getting into an infinite
+ loop of detecting the same sentence over and over again. The downside is that this may come at a cost to text accuracy, especially
+ with regards to unique words or names that appear in the audio. You can compensate for this by increasing the prompt window.
+
+ Note that English is very well handled by Whisper, and it's less susceptible to issues surrounding bad timings and infinite loops.
+ So you may only need to use a VAD for other languages, such as Japanese, or when the audio is very long (a sketch of how these
+ modes map to the classes in src/vad.py follows the list below).
+
+ * none
+   * Run Whisper on the entire audio input.
+ * silero-vad
+   * Use Silero VAD to detect sections that contain speech, and run Whisper independently on each section. Whisper is also run
+     on the gaps between each speech section, by either expanding the section up to the max merge size, or running Whisper independently
+     on the non-speech section.
+ * silero-vad-expand-into-gaps
+   * Use Silero VAD to detect sections that contain speech, and run Whisper independently on each section. Each speech section will be expanded
+     such that it covers any adjacent non-speech sections. For instance, if an audio file of one minute contains the speech sections
+     00:00 - 00:10 (A) and 00:30 - 00:40 (B), section (A) will be expanded to 00:00 - 00:30, and (B) will be expanded to 00:30 - 01:00.
+ * silero-vad-skip-gaps
+   * As above, but sections that don't contain speech according to Silero will be skipped. This will be slightly faster, but
+     may cause dialogue to be skipped.
+ * periodic-vad
+   * Create sections of speech every 'VAD - Max Merge Size' seconds. This is very fast and simple, but will potentially break
+     a sentence or word in two.
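A minimal sketch of how these modes map onto the classes added in this commit (mirroring transcribe_file in app.py; the model size, file name and timing values are placeholders):

```python
# "silero-vad" corresponds to NonSpeechStrategy.CREATE_SEGMENT,
# "silero-vad-skip-gaps" to SKIP, "silero-vad-expand-into-gaps" to EXPAND_SEGMENT,
# and "periodic-vad" uses VadPeriodicTranscription instead of Silero.
import whisper
from src.vad import NonSpeechStrategy, TranscriptionConfig, VadSileroTranscription

model = whisper.load_model("base")

# whisperCallable(audio, segment_index, prompt, detected_language) -> Whisper result dict
whisperCallable = lambda audio, segment_index, prompt, detected_language: \
    model.transcribe(audio, language=detected_language, initial_prompt=prompt)

config = TranscriptionConfig(non_speech_strategy=NonSpeechStrategy.CREATE_SEGMENT,
                             max_silent_period=5, max_merge_size=30,
                             segment_padding_left=1, segment_padding_right=1,
                             max_prompt_window=3)

result = VadSileroTranscription().transcribe("audio.mp3", whisperCallable, config)
print(result["text"])
```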
+
+ ## VAD - Merge Window (s)
+ If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged.
+
+ ## VAD - Max Merge Size (s)
+ Disables merging of adjacent speech sections if the merged section would be longer than this number of seconds.
+
+ ## VAD - Padding (s)
+ The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number
+ larger than zero ensures that Whisper is more likely to correctly transcribe a sentence at the beginning of
+ a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp
+ to each transcribed line. The default value is 1 second.
+
+ ## VAD - Prompt Window (s)
+ The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this
+ number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at
+ 10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds).
+
+ Note that detected lines in gaps between speech sections will not be included in the prompt
+ (if silero-vad or silero-vad-expand-into-gaps is used).
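The merge window, max merge size and padding options are implemented by merge_timestamps in src/segments.py (added below). A small worked example with made-up timestamps:

```python
# The first two sections are 3 s apart (within the 5 s merge window) and are
# merged; the third starts 10 s later and becomes a new section. Every
# resulting section is then padded by 1 s on each side.
from src.segments import merge_timestamps

sections = [
    {'start': 10.0, 'end': 20.0},
    {'start': 23.0, 'end': 27.0},
    {'start': 37.0, 'end': 40.0},
]

merged = merge_timestamps(sections, merge_window=5, max_merge_size=30,
                          padding_left=1, padding_right=1)
print(merged)  # [{'start': 9.0, 'end': 28.0}, {'start': 36.0, 'end': 41.0}]
```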
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ git+https://github.com/openai/whisper.git
+ transformers
+ ffmpeg-python==0.2.0
+ gradio
+ yt-dlp
+ torchaudio
src/__init__.py ADDED
File without changes
src/download.py ADDED
@@ -0,0 +1,72 @@
1
+ from tempfile import mkdtemp
2
+ from typing import List
3
+ from yt_dlp import YoutubeDL
4
+
5
+ import yt_dlp
6
+ from yt_dlp.postprocessor import PostProcessor
7
+
8
+ class FilenameCollectorPP(PostProcessor):
9
+ def __init__(self):
10
+ super(FilenameCollectorPP, self).__init__(None)
11
+ self.filenames = []
12
+
13
+ def run(self, information):
14
+ self.filenames.append(information["filepath"])
15
+ return [], information
16
+
17
+ def download_url(url: str, maxDuration: int = None, destinationDirectory: str = None, playlistItems: str = "1") -> List[str]:
18
+ try:
19
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate=None, destinationDirectory=destinationDirectory, playlistItems=playlistItems)
20
+ except yt_dlp.utils.DownloadError as e:
21
+ # In case of an OS error, try again with a different output template
22
+ if e.msg and e.msg.find("[Errno 36] File name too long") >= 0:
23
+ return _perform_download(url, maxDuration=maxDuration, outputTemplate="%(title).10s %(id)s.%(ext)s")
24
+ pass
25
+
26
+ def _perform_download(url: str, maxDuration: int = None, outputTemplate: str = None, destinationDirectory: str = None, playlistItems: str = "1"):
27
+ # Create a temporary directory to store the downloaded files
28
+ if destinationDirectory is None:
29
+ destinationDirectory = mkdtemp()
30
+
31
+ ydl_opts = {
32
+ "format": "bestaudio/best",
33
+ 'paths': {
34
+ 'home': destinationDirectory
35
+ }
36
+ }
37
+ if (playlistItems):
38
+ ydl_opts['playlist_items'] = playlistItems
39
+
40
+ # Add output template if specified
41
+ if outputTemplate:
42
+ ydl_opts['outtmpl'] = outputTemplate
43
+
44
+ filename_collector = FilenameCollectorPP()
45
+
46
+ with YoutubeDL(ydl_opts) as ydl:
47
+ if maxDuration and maxDuration > 0:
48
+ info = ydl.extract_info(url, download=False)
49
+ duration = info['duration']
50
+
51
+ if duration >= maxDuration:
52
+ raise ExceededMaximumDuration(videoDuration=duration, maxDuration=maxDuration, message="Video is too long")
53
+
54
+ ydl.add_post_processor(filename_collector)
55
+ ydl.download([url])
56
+
57
+ if len(filename_collector.filenames) <= 0:
58
+ raise Exception("Cannot download " + url)
59
+
60
+ result = []
61
+
62
+ for filename in filename_collector.filenames:
63
+ result.append(filename)
64
+ print("Downloaded " + filename)
65
+
66
+ return result
67
+
68
+ class ExceededMaximumDuration(Exception):
69
+ def __init__(self, videoDuration, maxDuration, message):
70
+ self.videoDuration = videoDuration
71
+ self.maxDuration = maxDuration
72
+ super().__init__(message)
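A short usage sketch for download_url (the URL is the one from the README; the 600-second limit matches the hosted demo's default):

```python
# Download the audio of a single video (playlistItems defaults to "1") and
# enforce a maximum duration, as the web UI does for remote inputs.
from src.download import ExceededMaximumDuration, download_url

try:
    files = download_url("https://www.youtube.com/watch?v=4cICErqqRSM", maxDuration=600)
    print("Downloaded:", files)
except ExceededMaximumDuration as e:
    print(f"Video is too long: {e.videoDuration}s exceeds the {e.maxDuration}s limit")
```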
src/segments.py ADDED
@@ -0,0 +1,55 @@
1
+ from typing import Any, Dict, List
2
+
3
+ import copy
4
+
5
+ def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
6
+ result = []
7
+
8
+ if len(timestamps) == 0:
9
+ return result
10
+ if max_merge_size is None:
11
+ return timestamps
12
+
13
+ if padding_left is None:
14
+ padding_left = 0
15
+ if padding_right is None:
16
+ padding_right = 0
17
+
18
+ processed_time = 0
19
+ current_segment = None
20
+
21
+ for i in range(len(timestamps)):
22
+ next_segment = timestamps[i]
23
+
24
+ delta = next_segment['start'] - processed_time
25
+
26
+ # Note that segments can still be longer than the max merge size, they just won't be merged in that case
27
+ if current_segment is None or (merge_window is not None and delta > merge_window) \
28
+ or next_segment['end'] - current_segment['start'] > max_merge_size:
29
+ # Finish the current segment
30
+ if current_segment is not None:
31
+ # Add right padding
32
+ finish_padding = min(padding_right, delta / 2) if delta < padding_left + padding_right else padding_right
33
+ current_segment['end'] += finish_padding
34
+ delta -= finish_padding
35
+
36
+ result.append(current_segment)
37
+
38
+ # Start a new segment
39
+ current_segment = copy.deepcopy(next_segment)
40
+
41
+ # Pad the segment
42
+ current_segment['start'] = current_segment['start'] - min(padding_left, delta)
43
+ processed_time = current_segment['end']
44
+
45
+ else:
46
+ # Merge the segment
47
+ current_segment['end'] = next_segment['end']
48
+ processed_time = current_segment['end']
49
+
50
+ # Add the last segment
51
+ if current_segment is not None:
52
+ current_segment['end'] += padding_right
53
+ result.append(current_segment)
54
+
55
+ return result
src/utils.py ADDED
@@ -0,0 +1,115 @@
1
+ import textwrap
2
+ import unicodedata
3
+ import re
4
+
5
+ import zlib
6
+ from typing import Iterator, TextIO
7
+
8
+
9
+ def exact_div(x, y):
10
+ assert x % y == 0
11
+ return x // y
12
+
13
+
14
+ def str2bool(string):
15
+ str2val = {"True": True, "False": False}
16
+ if string in str2val:
17
+ return str2val[string]
18
+ else:
19
+ raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
20
+
21
+
22
+ def optional_int(string):
23
+ return None if string == "None" else int(string)
24
+
25
+
26
+ def optional_float(string):
27
+ return None if string == "None" else float(string)
28
+
29
+
30
+ def compression_ratio(text) -> float:
31
+ return len(text) / len(zlib.compress(text.encode("utf-8")))
32
+
33
+
34
+ def format_timestamp(seconds: float, always_include_hours: bool = False, fractionalSeperator: str = '.'):
35
+ assert seconds >= 0, "non-negative timestamp expected"
36
+ milliseconds = round(seconds * 1000.0)
37
+
38
+ hours = milliseconds // 3_600_000
39
+ milliseconds -= hours * 3_600_000
40
+
41
+ minutes = milliseconds // 60_000
42
+ milliseconds -= minutes * 60_000
43
+
44
+ seconds = milliseconds // 1_000
45
+ milliseconds -= seconds * 1_000
46
+
47
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
48
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{fractionalSeperator}{milliseconds:03d}"
49
+
50
+
51
+ def write_txt(transcript: Iterator[dict], file: TextIO):
52
+ for segment in transcript:
53
+ print(segment['text'].strip(), file=file, flush=True)
54
+
55
+
56
+ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
57
+ print("WEBVTT\n", file=file)
58
+ for segment in transcript:
59
+ text = process_text(segment['text'], maxLineWidth).replace('-->', '->')
60
+
61
+ print(
62
+ f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
63
+ f"{text}\n",
64
+ file=file,
65
+ flush=True,
66
+ )
67
+
68
+
69
+ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
70
+ """
71
+ Write a transcript to a file in SRT format.
72
+ Example usage:
73
+ from pathlib import Path
74
+ from whisper.utils import write_srt
75
+ result = transcribe(model, audio_path, temperature=temperature, **args)
76
+ # save SRT
77
+ audio_basename = Path(audio_path).stem
78
+ with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
79
+ write_srt(result["segments"], file=srt)
80
+ """
81
+ for i, segment in enumerate(transcript, start=1):
82
+ text = process_text(segment['text'].strip(), maxLineWidth).replace('-->', '->')
83
+
84
+ # write srt lines
85
+ print(
86
+ f"{i}\n"
87
+ f"{format_timestamp(segment['start'], always_include_hours=True, fractionalSeperator=',')} --> "
88
+ f"{format_timestamp(segment['end'], always_include_hours=True, fractionalSeperator=',')}\n"
89
+ f"{text}\n",
90
+ file=file,
91
+ flush=True,
92
+ )
93
+
94
+ def process_text(text: str, maxLineWidth=None):
95
+ if (maxLineWidth is None or maxLineWidth < 0):
96
+ return text
97
+
98
+ lines = textwrap.wrap(text, width=maxLineWidth, tabsize=4)
99
+ return '\n'.join(lines)
100
+
101
+ def slugify(value, allow_unicode=False):
102
+ """
103
+ Taken from https://github.com/django/django/blob/master/django/utils/text.py
104
+ Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated
105
+ dashes to single dashes. Remove characters that aren't alphanumerics,
106
+ underscores, or hyphens. Convert to lowercase. Also strip leading and
107
+ trailing whitespace, dashes, and underscores.
108
+ """
109
+ value = str(value)
110
+ if allow_unicode:
111
+ value = unicodedata.normalize('NFKC', value)
112
+ else:
113
+ value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
114
+ value = re.sub(r'[^\w\s-]', '', value.lower())
115
+ return re.sub(r'[-\s]+', '-', value).strip('-_')
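Quick examples for the two helpers the app uses when writing output files (expected outputs worked out from the code above):

```python
from src.utils import format_timestamp, slugify

# SRT-style timestamp with a comma as the fractional separator
print(format_timestamp(3661.5, always_include_hours=True, fractionalSeperator=','))
# -> 01:01:01,500

# slugify is used to build safe file prefixes for the generated subtitles
print(slugify("Hello, World!"))  # -> hello-world
```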
src/vad.py ADDED
@@ -0,0 +1,477 @@
1
+ from abc import ABC, abstractmethod
2
+ from collections import Counter, deque
3
+
4
+ from typing import Any, Deque, Iterator, List, Dict
5
+
6
+ from pprint import pprint
7
+
8
+ from src.segments import merge_timestamps
9
+
10
+ # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
11
+ try:
12
+ import tensorflow as tf
13
+ except ModuleNotFoundError:
14
+ # Error handling
15
+ pass
16
+
17
+ import torch
18
+
19
+ import ffmpeg
20
+ import numpy as np
21
+
22
+ from src.utils import format_timestamp
23
+ from enum import Enum
24
+
25
+ class NonSpeechStrategy(Enum):
26
+ """
27
+ Ignore non-speech segments.
28
+ """
29
+ SKIP = 1
30
+ """
31
+ Just treat non-speech segments as speech.
32
+ """
33
+ CREATE_SEGMENT = 2
34
+ """
35
+ Expand speech segments into subsequent non-speech segments.
36
+ """
37
+ EXPAND_SEGMENT = 3
38
+
39
+ # Defaults for Silero
40
+ SPEECH_TRESHOLD = 0.3
41
+
42
+ # Minimum size of segments to process
43
+ MIN_SEGMENT_DURATION = 1
44
+
45
+ # The maximum time for texts from old segments to be used in the next segment
46
+ MAX_PROMPT_WINDOW = 0 # seconds (0 = disabled)
47
+ PROMPT_NO_SPEECH_PROB = 0.1 # Do not pass the text from segments with a no speech probability higher than this
48
+
49
+ VAD_MAX_PROCESSING_CHUNK = 60 * 60 # 60 minutes of audio
50
+
51
+ class TranscriptionConfig(ABC):
52
+ def __init__(self, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
53
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
54
+ max_merge_size: float = None, max_prompt_window: float = None):
55
+ self.non_speech_strategy = non_speech_strategy
56
+ self.segment_padding_left = segment_padding_left
57
+ self.segment_padding_right = segment_padding_right
58
+ self.max_silent_period = max_silent_period
59
+ self.max_merge_size = max_merge_size
60
+ self.max_prompt_window = max_prompt_window
61
+
62
+ class PeriodicTranscriptionConfig(TranscriptionConfig):
63
+ def __init__(self, periodic_duration: float, non_speech_strategy: NonSpeechStrategy = NonSpeechStrategy.SKIP,
64
+ segment_padding_left: float = None, segment_padding_right = None, max_silent_period: float = None,
65
+ max_merge_size: float = None, max_prompt_window: float = None):
66
+ super().__init__(non_speech_strategy, segment_padding_left, segment_padding_right, max_silent_period, max_merge_size, max_prompt_window)
67
+ self.periodic_duration = periodic_duration
68
+
69
+ class AbstractTranscription(ABC):
70
+ def __init__(self, sampling_rate: int = 16000):
71
+ self.sampling_rate = sampling_rate
72
+
73
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
74
+ return load_audio(str, self.sampling_rate, start_time, duration)
75
+
76
+ @abstractmethod
77
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
78
+ """
79
+ Get the start and end timestamps of the sections that should be transcribed by this VAD method.
80
+
81
+ Parameters
82
+ ----------
83
+ audio: str
84
+ The audio file.
85
+ config: TranscriptionConfig
86
+ The transcription configuration.
87
+
88
+ Returns
89
+ -------
90
+ A list of start and end timestamps, in fractional seconds.
91
+ """
92
+ return
93
+
94
+ def transcribe(self, audio: str, whisperCallable, config: TranscriptionConfig):
95
+ """
96
+ Transcribe the given audio file.
97
+
98
+ Parameters
99
+ ----------
100
+ audio: str
101
+ The audio file.
102
+
103
+ whisperCallable: Callable[[Union[str, np.ndarray, torch.Tensor], int, str, str], dict[str, Union[dict, Any]]]
104
+ The callback that is used to invoke Whisper on an audio file/buffer. The first parameter is the audio file/buffer,
105
+ the second is the segment index, the third is an optional text prompt, and the last is the currently detected language. The return value is the result of the Whisper call.
106
+
107
+ Returns
108
+ -------
109
+ The complete transcription result (text, segments and language), with segment timestamps adjusted to the full audio.
110
+ """
111
+
112
+ # get speech timestamps from full audio file
113
+ seconds_timestamps = self.get_transcribe_timestamps(audio, config)
114
+
115
+ #for seconds_timestamp in seconds_timestamps:
116
+ # print("VAD timestamp ", format_timestamp(seconds_timestamp['start']), " to ", format_timestamp(seconds_timestamp['end']))
117
+
118
+ merged = merge_timestamps(seconds_timestamps, config.max_silent_period, config.max_merge_size, config.segment_padding_left, config.segment_padding_right)
119
+
120
+ # A deque of transcribed segments that is passed to the next segment as a prompt
121
+ prompt_window = deque()
122
+
123
+ print("Timestamps:")
124
+ pprint(merged)
125
+
126
+ if config.non_speech_strategy != NonSpeechStrategy.SKIP:
127
+ max_audio_duration = get_audio_duration(audio)
128
+
129
+ # Expand segments to include the gaps between them
130
+ if (config.non_speech_strategy == NonSpeechStrategy.CREATE_SEGMENT):
131
+ # When we have a prompt window, we create speech segments between each segment if we exceed the merge size
132
+ merged = self.fill_gaps(merged, total_duration=max_audio_duration, max_expand_size=config.max_merge_size)
133
+ elif config.non_speech_strategy == NonSpeechStrategy.EXPAND_SEGMENT:
134
+ # With no prompt window, it is better to just expand the segments (this effectively passes the prompt to the next segment)
135
+ merged = self.expand_gaps(merged, total_duration=max_audio_duration)
136
+ else:
137
+ raise Exception("Unknown non-speech strategy: " + str(config.non_speech_strategy))
138
+
139
+ print("Transcribing non-speech:")
140
+ pprint(merged)
141
+
142
+ result = {
143
+ 'text': "",
144
+ 'segments': [],
145
+ 'language': ""
146
+ }
147
+ languageCounter = Counter()
148
+ detected_language = None
149
+
150
+ segment_index = -1
151
+
152
+ # For each time segment, run whisper
153
+ for segment in merged:
154
+ segment_index += 1
155
+ segment_start = segment['start']
156
+ segment_end = segment['end']
157
+ segment_expand_amount = segment.get('expand_amount', 0)
158
+ segment_gap = segment.get('gap', False)
159
+
160
+ segment_duration = segment_end - segment_start
161
+
162
+ if segment_duration < MIN_SEGMENT_DURATION:
163
+ continue;
164
+
165
+ # Audio to run on Whisper
166
+ segment_audio = self.get_audio_segment(audio, start_time = str(segment_start), duration = str(segment_duration))
167
+ # Previous segments to use as a prompt
168
+ segment_prompt = ' '.join([segment['text'] for segment in prompt_window]) if len(prompt_window) > 0 else None
169
+
170
+ # Detected language
171
+ detected_language = languageCounter.most_common(1)[0][0] if len(languageCounter) > 0 else None
172
+
173
+ print("Running whisper from ", format_timestamp(segment_start), " to ", format_timestamp(segment_end), ", duration: ",
174
+ segment_duration, "expanded: ", segment_expand_amount, "prompt: ", segment_prompt, "language: ", detected_language)
175
+ segment_result = whisperCallable(segment_audio, segment_index, segment_prompt, detected_language)
176
+
177
+ adjusted_segments = self.adjust_timestamp(segment_result["segments"], adjust_seconds=segment_start, max_source_time=segment_duration)
178
+
179
+ # Propagate expand amount to the segments
180
+ if (segment_expand_amount > 0):
181
+ segment_without_expansion = segment_duration - segment_expand_amount
182
+
183
+ for adjusted_segment in adjusted_segments:
184
+ adjusted_segment_end = adjusted_segment['end']
185
+
186
+ # Add expand amount if the segment got expanded
187
+ if (adjusted_segment_end > segment_without_expansion):
188
+ adjusted_segment["expand_amount"] = adjusted_segment_end - segment_without_expansion
189
+
190
+ # Append to output
191
+ result['text'] += segment_result['text']
192
+ result['segments'].extend(adjusted_segments)
193
+
194
+ # Increment detected language
195
+ if not segment_gap:
196
+ languageCounter[segment_result['language']] += 1
197
+
198
+ # Update prompt window
199
+ self.__update_prompt_window(prompt_window, adjusted_segments, segment_end, segment_gap, config)
200
+
201
+ if detected_language is not None:
202
+ result['language'] = detected_language
203
+
204
+ return result
205
+
206
+ def __update_prompt_window(self, prompt_window: Deque, adjusted_segments: List, segment_end: float, segment_gap: bool, config: TranscriptionConfig):
207
+ if (config.max_prompt_window is not None and config.max_prompt_window > 0):
208
+ # Add segments to the current prompt window (unless it is a speech gap)
209
+ if not segment_gap:
210
+ for segment in adjusted_segments:
211
+ if segment.get('no_speech_prob', 0) <= PROMPT_NO_SPEECH_PROB:
212
+ prompt_window.append(segment)
213
+
214
+ while (len(prompt_window) > 0):
215
+ first_end_time = prompt_window[0].get('end', 0)
216
+ # Time expanded in the segments should be discounted from the prompt window
217
+ first_expand_time = prompt_window[0].get('expand_amount', 0)
218
+
219
+ if (first_end_time - first_expand_time < segment_end - config.max_prompt_window):
220
+ prompt_window.popleft()
221
+ else:
222
+ break
223
+
224
+ def include_gaps(self, segments: Iterator[dict], min_gap_length: float, total_duration: float):
225
+ result = []
226
+ last_end_time = 0
227
+
228
+ for segment in segments:
229
+ segment_start = float(segment['start'])
230
+ segment_end = float(segment['end'])
231
+
232
+ if (last_end_time != segment_start):
233
+ delta = segment_start - last_end_time
234
+
235
+ if (min_gap_length is None or delta >= min_gap_length):
236
+ result.append( { 'start': last_end_time, 'end': segment_start, 'gap': True } )
237
+
238
+ last_end_time = segment_end
239
+ result.append(segment)
240
+
241
+ # Also include total duration if specified
242
+ if (total_duration is not None and last_end_time < total_duration):
243
+ delta = total_duration - segment_start
244
+
245
+ if (min_gap_length is None or delta >= min_gap_length):
246
+ result.append( { 'start': last_end_time, 'end': total_duration, 'gap': True } )
247
+
248
+ return result
249
+
250
+ # Expand the end time of each segment to the start of the next segment
251
+ def expand_gaps(self, segments: List[Dict[str, Any]], total_duration: float):
252
+ result = []
253
+
254
+ if len(segments) == 0:
255
+ return result
256
+
257
+ # Add gap at the beginning if needed
258
+ if (segments[0]['start'] > 0):
259
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
260
+
261
+ for i in range(len(segments) - 1):
262
+ current_segment = segments[i]
263
+ next_segment = segments[i + 1]
264
+
265
+ delta = next_segment['start'] - current_segment['end']
266
+
267
+ # Expand if the gap actually exists
268
+ if (delta >= 0):
269
+ current_segment = current_segment.copy()
270
+ current_segment['expand_amount'] = delta
271
+ current_segment['end'] = next_segment['start']
272
+
273
+ result.append(current_segment)
274
+
275
+ # Add last segment
276
+ last_segment = segments[-1]
277
+ result.append(last_segment)
278
+
279
+ # Also include total duration if specified
280
+ if (total_duration is not None):
281
+ last_segment = result[-1]
282
+
283
+ if (last_segment['end'] < total_duration):
284
+ last_segment = last_segment.copy()
285
+ last_segment['end'] = total_duration
286
+ result[-1] = last_segment
287
+
288
+ return result
289
+
290
+ def fill_gaps(self, segments: List[Dict[str, Any]], total_duration: float, max_expand_size: float = None):
291
+ result = []
292
+
293
+ if len(segments) == 0:
294
+ return result
295
+
296
+ # Add gap at the beginning if needed
297
+ if (segments[0]['start'] > 0):
298
+ result.append({ 'start': 0, 'end': segments[0]['start'], 'gap': True } )
299
+
300
+ for i in range(len(segments) - 1):
301
+ expanded = False
302
+ current_segment = segments[i]
303
+ next_segment = segments[i + 1]
304
+
305
+ delta = next_segment['start'] - current_segment['end']
306
+
307
+ if (max_expand_size is not None and delta <= max_expand_size):
308
+ # Just expand the current segment
309
+ current_segment = current_segment.copy()
310
+ current_segment['expand_amount'] = delta
311
+ current_segment['end'] = next_segment['start']
312
+ expanded = True
313
+
314
+ result.append(current_segment)
315
+
316
+ # Add a gap to the next segment if needed
317
+ if (delta >= 0 and not expanded):
318
+ result.append({ 'start': current_segment['end'], 'end': next_segment['start'], 'gap': True } )
319
+
320
+ # Add last segment
321
+ last_segment = segments[-1]
322
+ result.append(last_segment)
323
+
324
+ # Also include total duration if specified
325
+ if (total_duration is not None):
326
+ last_segment = result[-1]
327
+
328
+ delta = total_duration - last_segment['end']
329
+
330
+ if (delta > 0):
331
+ if (max_expand_size is not None and delta <= max_expand_size):
332
+ # Expand the last segment
333
+ last_segment = last_segment.copy()
334
+ last_segment['expand_amount'] = delta
335
+ last_segment['end'] = total_duration
336
+ result[-1] = last_segment
337
+ else:
338
+ result.append({ 'start': last_segment['end'], 'end': total_duration, 'gap': True } )
339
+
340
+ return result
341
+
342
+ def adjust_timestamp(self, segments: Iterator[dict], adjust_seconds: float, max_source_time: float = None):
343
+ result = []
344
+
345
+ for segment in segments:
346
+ segment_start = float(segment['start'])
347
+ segment_end = float(segment['end'])
348
+
349
+ # Filter segments?
350
+ if (max_source_time is not None):
351
+ if (segment_start > max_source_time):
352
+ continue
353
+ segment_end = min(max_source_time, segment_end)
354
+
355
+ new_segment = segment.copy()
356
+
357
+ # Add to start and end
358
+ new_segment['start'] = segment_start + adjust_seconds
359
+ new_segment['end'] = segment_end + adjust_seconds
360
+ result.append(new_segment)
361
+ return result
362
+
363
+ def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
364
+ result = []
365
+
366
+ for entry in timestamps:
367
+ start = entry['start']
368
+ end = entry['end']
369
+
370
+ result.append({
371
+ 'start': start * factor,
372
+ 'end': end * factor
373
+ })
374
+ return result
375
+
376
+ class VadSileroTranscription(AbstractTranscription):
377
+ def __init__(self, sampling_rate: int = 16000):
378
+ super().__init__(sampling_rate=sampling_rate)
379
+
380
+ self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
381
+ (self.get_speech_timestamps, _, _, _, _) = utils
382
+
383
+
384
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
385
+ audio_duration = get_audio_duration(audio)
386
+ result = []
387
+
388
+ # Divide processing of the audio into chunks
389
+ chunk_start = 0.0
390
+
391
+ while (chunk_start < audio_duration):
392
+ chunk_duration = min(audio_duration - chunk_start, VAD_MAX_PROCESSING_CHUNK)
393
+
394
+ print("Processing VAD in chunk from {} to {}".format(format_timestamp(chunk_start), format_timestamp(chunk_start + chunk_duration)))
395
+ wav = self.get_audio_segment(audio, str(chunk_start), str(chunk_duration))
396
+
397
+ sample_timestamps = self.get_speech_timestamps(wav, self.model, sampling_rate=self.sampling_rate, threshold=SPEECH_TRESHOLD)
398
+ seconds_timestamps = self.multiply_timestamps(sample_timestamps, factor=1 / self.sampling_rate)
399
+ adjusted = self.adjust_timestamp(seconds_timestamps, adjust_seconds=chunk_start, max_source_time=chunk_start + chunk_duration)
400
+
401
+ #pprint(adjusted)
402
+
403
+ result.extend(adjusted)
404
+ chunk_start += chunk_duration
405
+
406
+ return result
407
+
408
+ # A very simple VAD that just marks every N seconds as speech
409
+ class VadPeriodicTranscription(AbstractTranscription):
410
+ def __init__(self, sampling_rate: int = 16000):
411
+ super().__init__(sampling_rate=sampling_rate)
412
+
413
+ def get_transcribe_timestamps(self, audio: str, config: PeriodicTranscriptionConfig):
414
+ # Get duration in seconds
415
+ audio_duration = get_audio_duration(audio)
416
+ result = []
417
+
418
+ # Generate a timestamp every N seconds
419
+ start_timestamp = 0
420
+
421
+ while (start_timestamp < audio_duration):
422
+ end_timestamp = min(start_timestamp + config.periodic_duration, audio_duration)
423
+ segment_duration = end_timestamp - start_timestamp
424
+
425
+ # Minimum duration is 1 second
426
+ if (segment_duration >= 1):
427
+ result.append( { 'start': start_timestamp, 'end': end_timestamp } )
428
+
429
+ start_timestamp = end_timestamp
430
+
431
+ return result
432
+
433
+ def get_audio_duration(file: str):
434
+ return float(ffmpeg.probe(file)["format"]["duration"])
435
+
436
+ def load_audio(file: str, sample_rate: int = 16000,
437
+ start_time: str = None, duration: str = None):
438
+ """
439
+ Open an audio file and read as mono waveform, resampling as necessary
440
+
441
+ Parameters
442
+ ----------
443
+ file: str
444
+ The audio file to open
445
+
446
+ sr: int
447
+ The sample rate to resample the audio if necessary
448
+
449
+ start_time: str
450
+ The start time, using the standard FFMPEG time duration syntax, or None to disable.
451
+
452
+ duration: str
453
+ The duration, using the standard FFMPEG time duration syntax, or None to disable.
454
+
455
+ Returns
456
+ -------
457
+ A NumPy array containing the audio waveform, in float32 dtype.
458
+ """
459
+ try:
460
+ inputArgs = {'threads': 0}
461
+
462
+ if (start_time is not None):
463
+ inputArgs['ss'] = start_time
464
+ if (duration is not None):
465
+ inputArgs['t'] = duration
466
+
467
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
468
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
469
+ out, _ = (
470
+ ffmpeg.input(file, **inputArgs)
471
+ .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=sample_rate)
472
+ .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True)
473
+ )
474
+ except ffmpeg.Error as e:
475
+ raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}")
476
+
477
+ return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
tests/segments_test.py ADDED
@@ -0,0 +1,48 @@
1
+ import sys
2
+ import unittest
3
+
4
+ sys.path.append('../whisper-webui')
5
+
6
+ from src.segments import merge_timestamps
7
+
8
+ class TestSegments(unittest.TestCase):
9
+ def __init__(self, *args, **kwargs):
10
+ super(TestSegments, self).__init__(*args, **kwargs)
11
+
12
+ def test_merge_segments(self):
13
+ segments = [
14
+ {'start': 10.0, 'end': 20.0},
15
+ {'start': 22.0, 'end': 27.0},
16
+ {'start': 31.0, 'end': 35.0},
17
+ {'start': 45.0, 'end': 60.0},
18
+ {'start': 61.0, 'end': 65.0},
19
+ {'start': 68.0, 'end': 98.0},
20
+ {'start': 100.0, 'end': 102.0},
21
+ {'start': 110.0, 'end': 112.0}
22
+ ]
23
+
24
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
25
+
26
+ self.assertListEqual(result, [
27
+ {'start': 9.0, 'end': 36.0},
28
+ {'start': 44.0, 'end': 66.0},
29
+ {'start': 67.0, 'end': 99.0},
30
+ {'start': 99.0, 'end': 103.0},
31
+ {'start': 109.0, 'end': 113.0}
32
+ ])
33
+
34
+ def test_overlap_next(self):
35
+ segments = [
36
+ {'start': 5.0, 'end': 39.182},
37
+ {'start': 39.986, 'end': 40.814}
38
+ ]
39
+
40
+ result = merge_timestamps(segments, merge_window=5, max_merge_size=30, padding_left=1, padding_right=1)
41
+
42
+ self.assertListEqual(result, [
43
+ {'start': 4.0, 'end': 39.584},
44
+ {'start': 39.584, 'end': 41.814}
45
+ ])
46
+
47
+ if __name__ == '__main__':
48
+ unittest.main()
tests/vad_test.py ADDED
@@ -0,0 +1,66 @@
1
+ import pprint
2
+ import unittest
3
+ import numpy as np
4
+ import sys
5
+
6
+ sys.path.append('../whisper-webui')
7
+
8
+ from src.vad import AbstractTranscription, TranscriptionConfig, VadSileroTranscription
9
+
10
+ class TestVad(unittest.TestCase):
11
+ def __init__(self, *args, **kwargs):
12
+ super(TestVad, self).__init__(*args, **kwargs)
13
+ self.transcribe_calls = []
14
+
15
+ def test_transcript(self):
16
+ mock = MockVadTranscription()
17
+
18
+ self.transcribe_calls.clear()
19
+ result = mock.transcribe("mock", lambda audio, segment_index, prompt, detected_language : self.transcribe_segments(audio), TranscriptionConfig())
20
+
21
+ self.assertListEqual(self.transcribe_calls, [
22
+ [30, 30],
23
+ [100, 100]
24
+ ])
25
+
26
+ self.assertListEqual(result['segments'],
27
+ [{'end': 50.0, 'start': 40.0, 'text': 'Hello world '},
28
+ {'end': 120.0, 'start': 110.0, 'text': 'Hello world '}]
29
+ )
30
+
31
+ def transcribe_segments(self, segment):
32
+ self.transcribe_calls.append(segment.tolist())
33
+
34
+ # Dummy text
35
+ return {
36
+ 'text': "Hello world ",
37
+ 'segments': [
38
+ {
39
+ "start": 10.0,
40
+ "end": 20.0,
41
+ "text": "Hello world "
42
+ }
43
+ ],
44
+ 'language': ""
45
+ }
46
+
47
+ class MockVadTranscription(AbstractTranscription):
48
+ def __init__(self):
49
+ super().__init__()
50
+
51
+ def get_audio_segment(self, str, start_time: str = None, duration: str = None):
52
+ start_time_seconds = float(start_time.removesuffix("s"))
53
+ duration_seconds = float(duration.removesuffix("s"))
54
+
55
+ # For mocking, this just returns a simple numpy array
56
+ return np.array([start_time_seconds, duration_seconds], dtype=np.float64)
57
+
58
+ def get_transcribe_timestamps(self, audio: str, config: TranscriptionConfig):
59
+ result = []
60
+
61
+ result.append( { 'start': 30, 'end': 60 } )
62
+ result.append( { 'start': 100, 'end': 200 } )
63
+ return result
64
+
65
+ if __name__ == '__main__':
66
+ unittest.main()