Fedir Zadniprovskyi committed · Commit dc4f25f · 1 Parent(s): 8ad4ca5

chore: fix ruff errors
Files changed:
- faster_whisper_server/asr.py +2 -2
- faster_whisper_server/audio.py +8 -6
- faster_whisper_server/config.py +4 -4
- faster_whisper_server/core.py +15 -35
- faster_whisper_server/gradio_app.py +11 -26
- faster_whisper_server/logger.py +1 -3
- faster_whisper_server/main.py +38 -79
- faster_whisper_server/server_models.py +11 -21
- faster_whisper_server/transcriber.py +6 -2
- pyproject.toml +24 -7
- tests/api_model_test.py +4 -6
- tests/app_test.py +9 -13
- tests/conftest.py +4 -7
- tests/sse_test.py +9 -19
faster_whisper_server/asr.py CHANGED

@@ -1,6 +1,6 @@
 import asyncio
+from collections.abc import Iterable
 import time
-from typing import Iterable
 
 from faster_whisper import transcribe
 
@@ -45,7 +45,7 @@ class FasterWhisperASR:
         audio: Audio,
         prompt: str | None = None,
     ) -> tuple[Transcription, transcribe.TranscriptionInfo]:
-        """Wrapper around _transcribe so it can be used in async context"""
+        """Wrapper around _transcribe so it can be used in async context."""
         # is this the optimal way to execute a blocking call in an async context?
         # TODO: verify performance when running inference on a CPU
         return await asyncio.get_running_loop().run_in_executor(
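Note: the `transcribe` wrapper above offloads a blocking faster-whisper call onto the default executor so the event loop stays responsive. A minimal sketch of the same pattern, with illustrative names (`blocking_transcribe` is not the project's API):

import asyncio


def blocking_transcribe(audio: bytes) -> str:
    """Stand-in for a CPU-bound inference call."""
    return "transcript"


async def transcribe(audio: bytes) -> str:
    # run_in_executor(None, ...) uses the default ThreadPoolExecutor,
    # so the coroutine awaits the result without blocking the loop.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, blocking_transcribe, audio)


print(asyncio.run(transcribe(b"")))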
faster_whisper_server/audio.py CHANGED

@@ -1,15 +1,19 @@
 from __future__ import annotations
 
 import asyncio
-from typing import AsyncGenerator, BinaryIO
+from typing import TYPE_CHECKING, BinaryIO
 
 import numpy as np
 import soundfile as sf
-from numpy.typing import NDArray
 
 from faster_whisper_server.config import SAMPLES_PER_SECOND
 from faster_whisper_server.logger import logger
 
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+
+    from numpy.typing import NDArray
+
 
 def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
     audio_and_sample_rate = sf.read(
@@ -22,7 +26,7 @@ def audio_samples_from_file(file: BinaryIO) -> NDArray[np.float32]:
         endian="LITTLE",
     )
     audio = audio_and_sample_rate[0]
-    return audio  #
+    return audio  # pyright: ignore[reportReturnType]
 
 
 class Audio:
@@ -78,9 +82,7 @@ class AudioStream(Audio):
         self.modify_event.set()
         logger.info("AudioStream closed")
 
-    async def chunks(
-        self, min_duration: float
-    ) -> AsyncGenerator[NDArray[np.float32], None]:
+    async def chunks(self, min_duration: float) -> AsyncGenerator[NDArray[np.float32], None]:
         i = 0.0  # end time of last chunk
         while True:
             await self.modify_event.wait()
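Note: several files in this commit move typing-only imports behind `if TYPE_CHECKING:` (ruff's flake8-type-checking rules), which avoids runtime import cost and import cycles. A minimal sketch of the pattern:

from __future__ import annotations  # annotations are stored as strings, not evaluated

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; never imported at runtime.
    from collections.abc import AsyncGenerator


async def chunks() -> AsyncGenerator[int, None]:
    yield 1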
faster_whisper_server/config.py CHANGED

@@ -15,7 +15,7 @@ class ResponseFormat(enum.StrEnum):
     TEXT = "text"
     JSON = "json"
     VERBOSE_JSON = "verbose_json"
-    # NOTE: While inspecting outputs of these formats with `curl`, I noticed there's one or two "\n" inserted at the end of the response.
+    # NOTE: While inspecting outputs of these formats with `curl`, I noticed there's one or two "\n" inserted at the end of the response. # noqa: E501
 
     # VTT = "vtt"  # TODO
     # 1
@@ -185,8 +185,8 @@ class WhisperConfig(BaseModel):
 
 
 class Config(BaseSettings):
-    """
-
+    """Configuration for the application. Values can be set via environment variables.
+
     Pydantic will automatically handle mapping uppercased environment variables to the corresponding fields.
     To populate nested, the environment should be prefixed with the nested field name and an underscore. For example,
     the environment variable `LOG_LEVEL` will be mapped to `log_level`, `WHISPER_MODEL` to `whisper.model`, etc.
@@ -208,7 +208,7 @@ class Config(BaseSettings):
     max_inactivity_seconds: float = 5.0
     """
    Max allowed audio duration without any speech being detected before transcription is finilized and connection is closed.
-    """
+    """  # noqa: E501
     inactivity_window_seconds: float = 10.0
     """
     Controls how many latest seconds of audio are being passed through VAD.
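Note: the env-var mapping the new docstring describes comes from pydantic-settings. A hedged sketch of how such nested mapping is typically wired; whether the project uses `env_nested_delimiter` or custom logic is not shown in this diff:

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class WhisperConfig(BaseModel):
    model: str = "Systran/faster-whisper-tiny.en"


class Config(BaseSettings):
    # With env_nested_delimiter="_", WHISPER_MODEL populates whisper.model
    # and LOG_LEVEL populates log_level.
    model_config = SettingsConfigDict(env_nested_delimiter="_")

    log_level: str = "info"
    whisper: WhisperConfig = WhisperConfig()


print(Config().whisper.model)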
faster_whisper_server/core.py CHANGED

@@ -1,8 +1,8 @@
 # TODO: rename module
 from __future__ import annotations
 
-import re
 from dataclasses import dataclass
+import re
 
 from faster_whisper_server.config import config
 
@@ -18,10 +18,7 @@ class Segment:
     def is_eos(self) -> bool:
         if self.text.endswith("..."):
             return False
-        for punctuation_symbol in ".?!":
-            if self.text.endswith(punctuation_symbol):
-                return True
-        return False
+        return any(self.text.endswith(punctuation_symbol) for punctuation_symbol in ".?!")
 
     def offset(self, seconds: float) -> None:
         self.start += seconds
@@ -36,11 +33,7 @@ class Word(Segment):
     @classmethod
     def common_prefix(cls, a: list[Word], b: list[Word]) -> list[Word]:
         i = 0
-        while (
-            i < len(a)
-            and i < len(b)
-            and canonicalize_word(a[i].text) == canonicalize_word(b[i].text)
-        ):
+        while i < len(a) and i < len(b) and canonicalize_word(a[i].text) == canonicalize_word(b[i].text):
             i += 1
         return a[:i]
 
@@ -67,9 +60,7 @@ class Transcription:
         return self.end - self.start
 
     def after(self, seconds: float) -> Transcription:
-        return Transcription(
-            words=[word for word in self.words if word.start > seconds]
-        )
+        return Transcription(words=[word for word in self.words if word.start > seconds])
 
     def extend(self, words: list[Word]) -> None:
         self._ensure_no_word_overlap(words)
@@ -77,21 +68,16 @@ class Transcription:
 
     def _ensure_no_word_overlap(self, words: list[Word]) -> None:
         if len(self.words) > 0 and len(words) > 0:
-            if (
-                words[0].start + config.word_timestamp_error_margin
-                <= self.words[-1].end
-            ):
+            if words[0].start + config.word_timestamp_error_margin <= self.words[-1].end:
                 raise ValueError(
-                    f"Words overlap: {self.words[-1]} and {words[0]}. Error margin: {config.word_timestamp_error_margin}"
+                    f"Words overlap: {self.words[-1]} and {words[0]}. Error margin: {config.word_timestamp_error_margin}"  # noqa: E501
                 )
         for i in range(1, len(words)):
             if words[i].start + config.word_timestamp_error_margin <= words[i - 1].end:
-                raise ValueError(
-                    f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}"
-                )
+                raise ValueError(f"Words overlap: {words[i - 1]} and {words[i]}. All words: {words}")
 
 
-def test_segment_is_eos():
+def test_segment_is_eos() -> None:
     assert not Segment("Hello").is_eos
     assert not Segment("Hello...").is_eos
     assert Segment("Hello.").is_eos
@@ -117,16 +103,14 @@ def to_full_sentences(words: list[Word]) -> list[Segment]:
     return sentences
 
 
-def tests_to_full_sentences():
+def tests_to_full_sentences() -> None:
     assert to_full_sentences([]) == []
     assert to_full_sentences([Word(text="Hello")]) == []
     assert to_full_sentences([Word(text="Hello..."), Word(" world")]) == []
-    assert to_full_sentences([Word(text="Hello..."), Word(" world.")]) == [
-        Segment(text="Hello... world.")
-    ]
-    assert to_full_sentences(
-        [Word(text="Hello..."), Word(" world."), Word(" How")]
-    ) == [Segment(text="Hello... world.")]
+    assert to_full_sentences([Word(text="Hello..."), Word(" world.")]) == [Segment(text="Hello... world.")]
+    assert to_full_sentences([Word(text="Hello..."), Word(" world."), Word(" How")]) == [
+        Segment(text="Hello... world.")
+    ]
 
 
 def to_text(words: list[Word]) -> str:
@@ -144,7 +128,7 @@ def canonicalize_word(text: str) -> str:
     return text.lower().strip().strip(".,?!")
 
 
-def test_canonicalize_word():
+def test_canonicalize_word() -> None:
     assert canonicalize_word("ABC") == "abc"
     assert canonicalize_word("...ABC?") == "abc"
     assert canonicalize_word("... AbC ...") == "abc"
@@ -152,16 +136,12 @@ def test_canonicalize_word():
 
 def common_prefix(a: list[Word], b: list[Word]) -> list[Word]:
     i = 0
-    while (
-        i < len(a)
-        and i < len(b)
-        and canonicalize_word(a[i].text) == canonicalize_word(b[i].text)
-    ):
+    while i < len(a) and i < len(b) and canonicalize_word(a[i].text) == canonicalize_word(b[i].text):
         i += 1
     return a[:i]
 
 
-def test_common_prefix():
+def test_common_prefix() -> None:
     def word(text: str) -> Word:
         return Word(text=text, start=0.0, end=0.0, probability=0.0)
 
@@ -194,7 +174,7 @@ def test_common_prefix():
     assert common_prefix(a, b) == []
 
 
-def test_common_prefix_and_canonicalization():
+def test_common_prefix_and_canonicalization() -> None:
     def word(text: str) -> Word:
         return Word(text=text, start=0.0, end=0.0, probability=0.0)
 
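Note: the `is_eos` rewrite replaces an explicit loop with `any()` over a generator, the transformation ruff's SIM110 rule suggests. A small sketch of the equivalence:

def is_eos_loop(text: str) -> bool:
    # Pre-refactor shape: explicit loop with an early return.
    for punctuation_symbol in ".?!":
        if text.endswith(punctuation_symbol):
            return True
    return False


def is_eos_any(text: str) -> bool:
    # Post-refactor shape: any() short-circuits just like the early return.
    return any(text.endswith(punctuation_symbol) for punctuation_symbol in ".?!")


assert is_eos_loop("Hello.") and is_eos_any("Hello.")
assert not is_eos_loop("Hello") and not is_eos_any("Hello")

Since `str.endswith` also accepts a tuple, `text.endswith((".", "?", "!"))` would be an even shorter equivalent.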
faster_whisper_server/gradio_app.py CHANGED

@@ -1,5 +1,5 @@
+from collections.abc import Generator
 import os
-from typing import Generator
 
 import gradio as gr
 import httpx
@@ -13,26 +13,20 @@ TRANSLATION_ENDPOINT = "/v1/audio/translations"
 
 def create_gradio_demo(config: Config) -> gr.Blocks:
     host = os.getenv("UVICORN_HOST", "0.0.0.0")
-    port = os.getenv("UVICORN_PORT", 8000)
+    port = int(os.getenv("UVICORN_PORT", "8000"))
     # NOTE: worth looking into generated clients
     http_client = httpx.Client(base_url=f"http://{host}:{port}", timeout=None)
 
-    def handler(
-        file_path: str, model: str, task: Task, temperature: float, stream: bool
-    ) -> Generator[str, None, None]:
+    def handler(file_path: str, model: str, task: Task, temperature: float, stream: bool) -> Generator[str, None, None]:
         if stream:
             previous_transcription = ""
-            for transcription in transcribe_audio_streaming(
-                file_path, task, temperature, model
-            ):
+            for transcription in transcribe_audio_streaming(file_path, task, temperature, model):
                 previous_transcription += transcription
                 yield previous_transcription
         else:
             yield transcribe_audio(file_path, task, temperature, model)
 
-    def transcribe_audio(
-        file_path: str, task: Task, temperature: float, model: str
-    ) -> str:
+    def transcribe_audio(file_path: str, task: Task, temperature: float, model: str) -> str:
         if task == Task.TRANSCRIBE:
             endpoint = TRANSCRIPTION_ENDPOINT
         elif task == Task.TRANSLATE:
@@ -65,11 +59,7 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
                 "stream": True,
             },
         }
-        endpoint = (
-            TRANSCRIPTION_ENDPOINT
-            if task == Task.TRANSCRIBE
-            else TRANSLATION_ENDPOINT
-        )
+        endpoint = TRANSCRIPTION_ENDPOINT if task == Task.TRANSCRIBE else TRANSLATION_ENDPOINT
         with connect_sse(http_client, "POST", endpoint, **kwargs) as event_source:
             for event in event_source.iter_sse():
                 yield event.data
@@ -79,18 +69,15 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
         res_data = res.json()
         models: list[str] = [model["id"] for model in res_data]
         assert config.whisper.model in models
-        recommended_models = set(
-            model for model in models if model.startswith("Systran")
-        )
+        recommended_models = {model for model in models if model.startswith("Systran")}
         other_models = [model for model in models if model not in recommended_models]
         models = list(recommended_models) + other_models
-        model_dropdown = gr.Dropdown(
+        return gr.Dropdown(
             # no idea why it's complaining
-            choices=models,  #
+            choices=models,  # pyright: ignore[reportArgumentType]
             label="Model",
             value=config.whisper.model,
         )
-        return model_dropdown
 
     model_dropdown = gr.Dropdown(
         choices=[config.whisper.model],
@@ -102,13 +89,11 @@ def create_gradio_demo(config: Config) -> gr.Blocks:
         label="Task",
         value=Task.TRANSCRIBE,
     )
-    temperature_slider = gr.Slider(
-        minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0
-    )
+    temperature_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, label="Temperature", value=0.0)
     stream_checkbox = gr.Checkbox(label="Stream", value=True)
     with gr.Interface(
         title="Whisper Playground",
-        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",
+        description="""Consider supporting the project by starring the <a href="https://github.com/fedirz/faster-whisper-server">repository on GitHub</a>.""",  # noqa: E501
         inputs=[
             gr.Audio(type="filepath"),
             model_dropdown,
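Note: `os.getenv` returns a string for an env override but passes a non-string default through unchanged, so the old `os.getenv("UVICORN_PORT", 8000)` could yield either `int` or `str`. The fix normalizes the type at the boundary:

import os

# Before: the type depends on whether the variable is set (str) or not (int).
port_maybe_str = os.getenv("UVICORN_PORT", 8000)

# After: convert once; port is always an int.
port = int(os.getenv("UVICORN_PORT", "8000"))
print(type(port_maybe_str), port)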
faster_whisper_server/logger.py CHANGED

@@ -8,6 +8,4 @@ root_logger = logging.getLogger()
 root_logger.setLevel(logging.CRITICAL)
 logger = logging.getLogger(__name__)
 logger.setLevel(config.log_level.upper())
-logging.basicConfig(
-    format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s"
-)
+logging.basicConfig(format="%(asctime)s:%(levelname)s:%(name)s:%(funcName)s:%(message)s")
faster_whisper_server/main.py CHANGED

@@ -1,12 +1,11 @@
 from __future__ import annotations
 
 import asyncio
-import time
+from collections import OrderedDict
 from io import BytesIO
-from typing import Annotated, Generator, Iterable, Literal, OrderedDict
+import time
+from typing import TYPE_CHECKING, Annotated, Literal
 
-import gradio as gr
-import huggingface_hub
 from fastapi import (
     FastAPI,
     Form,
@@ -21,9 +20,9 @@ from fastapi import (
 from fastapi.responses import StreamingResponse
 from fastapi.websockets import WebSocketState
 from faster_whisper import WhisperModel
-from faster_whisper.transcribe import Segment, TranscriptionInfo
 from faster_whisper.vad import VadOptions, get_speech_timestamps
-from huggingface_hub.hf_api import ModelInfo
+import gradio as gr
+import huggingface_hub
 from pydantic import AfterValidator
 
 from faster_whisper_server import utils
@@ -45,6 +44,12 @@ from faster_whisper_server.server_models import (
 )
 from faster_whisper_server.transcriber import audio_transcriber
 
+if TYPE_CHECKING:
+    from collections.abc import Generator, Iterable
+
+    from faster_whisper.transcribe import Segment, TranscriptionInfo
+    from huggingface_hub.hf_api import ModelInfo
+
 loaded_models: OrderedDict[str, WhisperModel] = OrderedDict()
 
 
@@ -54,9 +59,7 @@ def load_model(model_name: str) -> WhisperModel:
         return loaded_models[model_name]
     if len(loaded_models) >= config.max_models:
         oldest_model_name = next(iter(loaded_models))
-        logger.info(
-            f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}"
-        )
+        logger.info(f"Max models ({config.max_models}) reached. Unloading the oldest model: {oldest_model_name}")
         del loaded_models[oldest_model_name]
     logger.debug(f"Loading {model_name}...")
     start = time.perf_counter()
@@ -67,7 +70,7 @@ def load_model(model_name: str) -> WhisperModel:
         compute_type=config.whisper.compute_type,
     )
     logger.info(
-        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."
+        f"Loaded {model_name} loaded in {time.perf_counter() - start:.2f} seconds. {config.whisper.inference_device}({config.whisper.compute_type}) will be used for inference."  # noqa: E501
     )
     loaded_models[model_name] = whisper
     return whisper
@@ -102,9 +105,7 @@ def get_models() -> list[ModelObject]:
 def get_model(
     model_name: Annotated[str, Path(example="Systran/faster-distil-whisper-large-v3")],
 ) -> ModelObject:
-    models = list(
-        huggingface_hub.list_models(model_name=model_name, library="ctranslate2")
-    )
+    models = list(huggingface_hub.list_models(model_name=model_name, library="ctranslate2"))
     if len(models) == 0:
         raise HTTPException(status_code=404, detail="Model doesn't exists")
     exact_match: ModelInfo | None = None
@@ -132,14 +133,12 @@ def segments_to_response(
     response_format: ResponseFormat,
 ) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse:
     segments = list(segments)
-    if response_format == ResponseFormat.TEXT:
+    if response_format == ResponseFormat.TEXT:  # noqa: RET503
         return utils.segments_text(segments)
     elif response_format == ResponseFormat.JSON:
         return TranscriptionJsonResponse.from_segments(segments)
     elif response_format == ResponseFormat.VERBOSE_JSON:
-        return TranscriptionVerboseJsonResponse.from_segments(
-            segments, transcription_info
-        )
+        return TranscriptionVerboseJsonResponse.from_segments(segments, transcription_info)
 
 
 def format_as_sse(data: str) -> str:
@@ -156,26 +155,21 @@ def segments_to_streaming_response(
             if response_format == ResponseFormat.TEXT:
                 data = segment.text
             elif response_format == ResponseFormat.JSON:
-                data = TranscriptionJsonResponse.from_segments(
-                    [segment]
-                ).model_dump_json()
+                data = TranscriptionJsonResponse.from_segments([segment]).model_dump_json()
             elif response_format == ResponseFormat.VERBOSE_JSON:
-                data = TranscriptionVerboseJsonResponse.from_segment(
-                    segment, transcription_info
-                ).model_dump_json()
+                data = TranscriptionVerboseJsonResponse.from_segment(segment, transcription_info).model_dump_json()
             yield format_as_sse(data)
 
     return StreamingResponse(segment_responses(), media_type="text/event-stream")
 
 
 def handle_default_openai_model(model_name: str) -> str:
-    """
+    """Exists because some callers may not be able override the default("whisper-1") model name.
+
     For example, https://github.com/open-webui/open-webui/issues/2248#issuecomment-2162997623.
     """
     if model_name == "whisper-1":
-        logger.info(
-            f"{model_name} is not a valid model name. Using {config.whisper.model} instead."
-        )
+        logger.info(f"{model_name} is not a valid model name. Using {config.whisper.model} instead.")
         return config.whisper.model
     return model_name
@@ -194,12 +188,7 @@ def translate_file(
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
     stream: Annotated[bool, Form()] = False,
-) -> (
-    str
-    | TranscriptionJsonResponse
-    | TranscriptionVerboseJsonResponse
-    | StreamingResponse
-):
+) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -210,9 +199,7 @@ def translate_file(
     )
 
     if stream:
-        return segments_to_streaming_response(
-            segments, transcription_info, response_format
-        )
+        return segments_to_streaming_response(segments, transcription_info, response_format)
     else:
         return segments_to_response(segments, transcription_info, response_format)
@@ -231,16 +218,11 @@ def transcribe_file(
     response_format: Annotated[ResponseFormat, Form()] = config.default_response_format,
     temperature: Annotated[float, Form()] = 0.0,
     timestamp_granularities: Annotated[
-        list[Literal["segment"] | Literal["word"]],
+        list[Literal["segment", "word"]],
         Form(alias="timestamp_granularities[]"),
     ] = ["segment"],
     stream: Annotated[bool, Form()] = False,
-) -> (
-    str
-    | TranscriptionJsonResponse
-    | TranscriptionVerboseJsonResponse
-    | StreamingResponse
-):
+) -> str | TranscriptionJsonResponse | TranscriptionVerboseJsonResponse | StreamingResponse:
     whisper = load_model(model)
     segments, transcription_info = whisper.transcribe(
         file.file,
@@ -253,9 +235,7 @@ def transcribe_file(
     )
 
     if stream:
-        return segments_to_streaming_response(
-            segments, transcription_info, response_format
-        )
+        return segments_to_streaming_response(segments, transcription_info, response_format)
     else:
         return segments_to_response(segments, transcription_info, response_format)
@@ -263,39 +243,28 @@
 async def audio_receiver(ws: WebSocket, audio_stream: AudioStream) -> None:
     try:
         while True:
-            bytes_ = await asyncio.wait_for(
-                ws.receive_bytes(), timeout=config.max_no_data_seconds
-            )
+            bytes_ = await asyncio.wait_for(ws.receive_bytes(), timeout=config.max_no_data_seconds)
             logger.debug(f"Received {len(bytes_)} bytes of audio data")
             audio_samples = audio_samples_from_file(BytesIO(bytes_))
             audio_stream.extend(audio_samples)
             if audio_stream.duration - config.inactivity_window_seconds >= 0:
-                audio = audio_stream.after(
-                    audio_stream.duration - config.inactivity_window_seconds
-                )
+                audio = audio_stream.after(audio_stream.duration - config.inactivity_window_seconds)
                 vad_opts = VadOptions(min_silence_duration_ms=500, speech_pad_ms=0)
                 # NOTE: This is a synchronous operation that runs every time new data is received.
-                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.
+                # This shouldn't be an issue unless data is being received in tiny chunks or the user's machine is a potato.  # noqa: E501
                 timestamps = get_speech_timestamps(audio.data, vad_opts)
                 if len(timestamps) == 0:
-                    logger.info(
-                        f"No speech detected in the last {config.inactivity_window_seconds} seconds."
-                    )
+                    logger.info(f"No speech detected in the last {config.inactivity_window_seconds} seconds.")
                     break
                 elif (
                     # last speech end time
-                    config.inactivity_window_seconds
-                    - timestamps[-1]["end"] / SAMPLES_PER_SECOND
+                    config.inactivity_window_seconds - timestamps[-1]["end"] / SAMPLES_PER_SECOND
                     >= config.max_inactivity_seconds
                 ):
-                    logger.info(
-                        f"Not enough speech in the last {config.inactivity_window_seconds} seconds."
-                    )
+                    logger.info(f"Not enough speech in the last {config.inactivity_window_seconds} seconds.")
                     break
-    except asyncio.TimeoutError:
-        logger.info(
-            f"No data received in {config.max_no_data_seconds} seconds. Closing the connection."
-        )
+    except TimeoutError:
+        logger.info(f"No data received in {config.max_no_data_seconds} seconds. Closing the connection.")
     except WebSocketDisconnect as e:
         logger.info(f"Client disconnected: {e}")
     audio_stream.close()
@@ -306,9 +275,7 @@ async def transcribe_stream(
     ws: WebSocket,
     model: Annotated[ModelName, Query()] = config.whisper.model,
     language: Annotated[Language | None, Query()] = config.default_language,
-    response_format: Annotated[
-        ResponseFormat, Query()
-    ] = config.default_response_format,
+    response_format: Annotated[ResponseFormat, Query()] = config.default_response_format,
     temperature: Annotated[float, Query()] = 0.0,
 ) -> None:
     await ws.accept()
@@ -331,19 +298,11 @@ async def transcribe_stream(
         if response_format == ResponseFormat.TEXT:
             await ws.send_text(transcription.text)
         elif response_format == ResponseFormat.JSON:
-            await ws.send_json(
-                TranscriptionJsonResponse.from_transcription(
-                    transcription
-                ).model_dump()
-            )
+            await ws.send_json(TranscriptionJsonResponse.from_transcription(transcription).model_dump())
         elif response_format == ResponseFormat.VERBOSE_JSON:
-            await ws.send_json(
-                TranscriptionVerboseJsonResponse.from_transcription(
-                    transcription
-                ).model_dump()
-            )
+            await ws.send_json(TranscriptionVerboseJsonResponse.from_transcription(transcription).model_dump())
 
-    if not ws.client_state == WebSocketState.DISCONNECTED:
+    if ws.client_state != WebSocketState.DISCONNECTED:
         logger.info("Closing the connection.")
         await ws.close()
 
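Note: the `except asyncio.TimeoutError` → `except TimeoutError` change (ruff UP041) is safe on modern Python: since 3.11, `asyncio.TimeoutError` is an alias of the builtin `TimeoutError`. A quick sketch:

import asyncio


async def main() -> None:
    try:
        # Sleep longer than the timeout to force a timeout.
        await asyncio.wait_for(asyncio.sleep(1), timeout=0.01)
    except TimeoutError:  # also catches asyncio.TimeoutError on Python >= 3.11
        print("timed out")


asyncio.run(main())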
faster_whisper_server/server_models.py CHANGED

@@ -1,12 +1,15 @@
 from __future__ import annotations
 
-from typing import Literal
+from typing import TYPE_CHECKING, Literal
 
-from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
 from pydantic import BaseModel, ConfigDict, Field
 
 from faster_whisper_server import utils
-from faster_whisper_server.core import Transcription
+
+if TYPE_CHECKING:
+    from faster_whisper.transcribe import Segment, TranscriptionInfo, Word
+
+    from faster_whisper_server.core import Transcription
 
 
 # https://platform.openai.com/docs/api-reference/audio/json-object
@@ -18,9 +21,7 @@ class TranscriptionJsonResponse(BaseModel):
         return cls(text=utils.segments_text(segments))
 
     @classmethod
-    def from_transcription(
-        cls, transcription: Transcription
-    ) -> TranscriptionJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> TranscriptionJsonResponse:
         return cls(text=transcription.text)
 
 
@@ -78,18 +79,12 @@ class TranscriptionVerboseJsonResponse(BaseModel):
     segments: list[SegmentObject]
 
     @classmethod
-    def from_segment(
-        cls, segment: Segment, transcription_info: TranscriptionInfo
-    ) -> TranscriptionVerboseJsonResponse:
+    def from_segment(cls, segment: Segment, transcription_info: TranscriptionInfo) -> TranscriptionVerboseJsonResponse:
         return cls(
             language=transcription_info.language,
             duration=segment.end - segment.start,
             text=segment.text,
-            words=(
-                [WordObject.from_word(word) for word in segment.words]
-                if isinstance(segment.words, list)
-                else []
-            ),
+            words=([WordObject.from_word(word) for word in segment.words] if isinstance(segment.words, list) else []),
             segments=[SegmentObject.from_segment(segment)],
         )
 
@@ -102,16 +97,11 @@ class TranscriptionVerboseJsonResponse(BaseModel):
             duration=transcription_info.duration,
             text=utils.segments_text(segments),
             segments=[SegmentObject.from_segment(segment) for segment in segments],
-            words=[
-                WordObject.from_word(word)
-                for word in utils.words_from_segments(segments)
-            ],
+            words=[WordObject.from_word(word) for word in utils.words_from_segments(segments)],
         )
 
     @classmethod
-    def from_transcription(
-        cls, transcription: Transcription
-    ) -> TranscriptionVerboseJsonResponse:
+    def from_transcription(cls, transcription: Transcription) -> TranscriptionVerboseJsonResponse:
         return cls(
             language="english",  # FIX: hardcoded
             duration=transcription.duration,
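Note: moving `Segment`, `TranscriptionInfo`, `Word`, and `Transcription` into the `TYPE_CHECKING` block only works here because the module uses `from __future__ import annotations`: annotations are stored as strings and never evaluated at runtime, so typing-only names are fine in method signatures. A sketch:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from decimal import Decimal  # stands in for a typing-only dependency


def total(amount: Decimal) -> str:
    # At runtime this annotation is just the string "Decimal".
    return f"{amount}"


print(total.__annotations__)  # {'amount': 'Decimal', 'return': 'str'}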
faster_whisper_server/transcriber.py CHANGED

@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-from typing import AsyncGenerator
+from typing import TYPE_CHECKING
 
-from faster_whisper_server.asr import FasterWhisperASR
 from faster_whisper_server.audio import Audio, AudioStream
 from faster_whisper_server.config import config
 from faster_whisper_server.core import (
@@ -13,6 +12,11 @@ from faster_whisper_server.core import (
 )
 from faster_whisper_server.logger import logger
 
+if TYPE_CHECKING:
+    from collections.abc import AsyncGenerator
+
+    from faster_whisper_server.asr import FasterWhisperASR
+
 
 class LocalAgreement:
     def __init__(self) -> None:
pyproject.toml CHANGED

@@ -28,18 +28,35 @@ target-version = "py312"
 [tool.ruff.lint]
 select = ["ALL"]
 ignore = [
-    "…",
+    "FIX",
+    "TD",  # disable todo warnings
     "ERA",  # allow commented out code
-    "…",
-    "FIX002",  # disable TODO warnings
+    "PTH",
 
+    "ANN003",  # missing kwargs
+    "ANN101",  # missing self type
+    "ANN102",  # missing cls
+    "B006",
+    "B008",
     "COM812",  # trailing comma
-    "…",
+    "D10",  # disabled required docstrings
+    "D401",
+    "EM102",
+    "FBT001",
+    "FBT002",
+    "PLR0913",
+    "PLR2004",  # magic
+    "RET504",
+    "RET505",
+    "RET508",
     "S101",  # allow assert
-    "…",
+    "S104",
     "S603",  # subprocess untrusted input
-
-    "…",
+    "SIM102",
+    "T201",  # print
+    "TRY003",
+    "W505",
+    "ISC001"  # recommended to disable for formatting
 ]
 
 [tool.ruff.lint.isort]
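Note: several of the newly ignored rules (e.g. RET504/RET505/RET508) would otherwise force rewrites such as dropping `else` after `return`; ignoring them preserves the codebase's explicit `if stream: ... else: ...` style. For illustration, the pattern RET505 flags (names hypothetical):

def pick_endpoint(stream: bool) -> str:
    if stream:
        return "/stream"
    else:  # RET505: "superfluous else after return"; the project opts to keep it
        return "/batch"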
tests/api_model_test.py CHANGED

@@ -4,9 +4,7 @@ from faster_whisper_server.server_models import ModelObject
 
 MODEL_THAT_EXISTS = "Systran/faster-whisper-tiny.en"
 MODEL_THAT_DOES_NOT_EXIST = "i-do-not-exist"
-MIN_EXPECTED_NUMBER_OF_MODELS = (
-    200  # At the time of the test creation there are 228 models
-)
+MIN_EXPECTED_NUMBER_OF_MODELS = 200  # At the time of the test creation there are 228 models
 
 
 # HACK: because ModelObject(**data) doesn't work
@@ -19,20 +17,20 @@ def model_dict_to_object(model_dict: dict) -> ModelObject:
     )
 
 
-def test_list_models(client: TestClient):
+def test_list_models(client: TestClient) -> None:
     response = client.get("/v1/models")
     data = response.json()
     models = [model_dict_to_object(model_dict) for model_dict in data]
     assert len(models) > MIN_EXPECTED_NUMBER_OF_MODELS
 
 
-def test_model_exists(client: TestClient):
+def test_model_exists(client: TestClient) -> None:
     response = client.get(f"/v1/models/{MODEL_THAT_EXISTS}")
     data = response.json()
     model = model_dict_to_object(data)
     assert model.id == MODEL_THAT_EXISTS
 
 
-def test_model_does_not_exist(client: TestClient):
+def test_model_does_not_exist(client: TestClient) -> None:
     response = client.get(f"/v1/models/{MODEL_THAT_DOES_NOT_EXIST}")
     assert response.status_code == 404
tests/app_test.py CHANGED

@@ -1,10 +1,10 @@
+from collections.abc import Generator
 import json
 import os
 import time
-from typing import Generator
 
-import pytest
 from fastapi.testclient import TestClient
+import pytest
 from starlette.testclient import WebSocketTestSession
 
 from faster_whisper_server.config import BYTES_PER_SECOND
@@ -22,35 +22,31 @@ def ws(client: TestClient) -> Generator[WebSocketTestSession, None, None]:
     yield ws
 
 
-def get_audio_file_paths():
-    file_paths = []
+def get_audio_file_paths() -> list[str]:
+    file_paths: list[str] = []
     directory = "tests/data"
     for filename in sorted(os.listdir(directory)[:AUDIO_FILES_LIMIT]):
-        file_paths.append(os.path.join(directory, filename))
+        file_paths.append(os.path.join(directory, filename))  # noqa: PERF401
     return file_paths
 
 
 file_paths = get_audio_file_paths()
 
 
-def stream_audio_data(
-    ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0
-):
+def stream_audio_data(ws: WebSocketTestSession, data: bytes, *, chunk_size: int = 4000, speed: float = 1.0) -> None:
     for i in range(0, len(data), chunk_size):
         ws.send_bytes(data[i : i + chunk_size])
         delay = len(data[i : i + chunk_size]) / BYTES_PER_SECOND / speed
         time.sleep(delay)
 
 
-def transcribe_audio_data(
-    client: TestClient, data: bytes
-) -> TranscriptionVerboseJsonResponse:
+def transcribe_audio_data(client: TestClient, data: bytes) -> TranscriptionVerboseJsonResponse:
     response = client.post(
         TRANSCRIBE_ENDPOINT,
         files={"file": ("audio.raw", data, "audio/raw")},
     )
     data = json.loads(response.json())  # TODO: figure this out
-    return TranscriptionVerboseJsonResponse(**data)  #
+    return TranscriptionVerboseJsonResponse(**data)  # pyright: ignore[reportCallIssue]
 
 
 # @pytest.mark.parametrize("file_path", file_paths)
@@ -60,7 +56,7 @@ def transcribe_audio_data(
 # with open(file_path, "rb") as file:
 #     data = file.read()
 #
-#     streaming_transcription: TranscriptionVerboseJsonResponse = None  # type: ignore
+#     streaming_transcription: TranscriptionVerboseJsonResponse = None  # type: ignore  # noqa: PGH003
 #     thread = threading.Thread(
 #         target=stream_audio_data, args=(ws, data), kwargs={"speed": 4.0}
 #     )
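Note: `# noqa: PERF401` silences ruff's suggestion to build the list with a comprehension instead of `append` inside a loop. For comparison, the form PERF401 would prefer (filenames here are illustrative):

import os

directory = "tests/data"
filenames = ["a.wav", "b.wav", "c.wav"]  # stand-in for os.listdir(directory)

file_paths = [os.path.join(directory, name) for name in sorted(filenames)]
print(file_paths)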
tests/conftest.py CHANGED

@@ -1,18 +1,15 @@
+from collections.abc import Generator
 import logging
-import os
-from typing import Generator
 
-import pytest
 from fastapi.testclient import TestClient
+import pytest
 
-
-os.environ["WHISPER_MODEL"] = "Systran/faster-whisper-tiny.en"
-from faster_whisper_server.main import app  # noqa: E402
+from faster_whisper_server.main import app
 
 disable_loggers = ["multipart.multipart", "faster_whisper"]
 
 
-def pytest_configure():
+def pytest_configure() -> None:
     for logger_name in disable_loggers:
         logger = logging.getLogger(logger_name)
         logger.disabled = True
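Note: `Generator` now comes from `collections.abc`, since the `typing` alias is deprecated (ruff UP035). The fixture this conftest exposes is not shown in the diff, but a yield-style `client` fixture would typically be annotated like this (a hedged sketch, not the project's exact code):

from collections.abc import Generator

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient

app = FastAPI()  # stand-in for faster_whisper_server.main.app


@pytest.fixture()
def client() -> Generator[TestClient, None, None]:
    # Generator[<yielded type>, <send type>, <return type>]
    with TestClient(app) as test_client:
        yield test_client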
tests/sse_test.py CHANGED

@@ -1,9 +1,9 @@
 import json
 import os
 
-import pytest
 from fastapi.testclient import TestClient
 from httpx_sse import connect_sse
+import pytest
 
 from faster_whisper_server.server_models import (
     TranscriptionJsonResponse,
@@ -17,15 +17,11 @@ ENDPOINTS = [
 ]
 
 
-parameters = [
-    (file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS
-]
+parameters = [(file_path, endpoint) for endpoint in ENDPOINTS for file_path in FILE_PATHS]
 
 
-@pytest.mark.parametrize("file_path,endpoint", parameters)
-def test_streaming_transcription_text(
-    client: TestClient, file_path: str, endpoint: str
-):
+@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
+def test_streaming_transcription_text(client: TestClient, file_path: str, endpoint: str) -> None:
     extension = os.path.splitext(file_path)[1]
     with open(file_path, "rb") as f:
         data = f.read()
@@ -36,15 +32,11 @@ def test_streaming_transcription_text(
     with connect_sse(client, "POST", endpoint, **kwargs) as event_source:
         for event in event_source.iter_sse():
             print(event)
-            assert (
-                len(event.data) > 1
-            )  # HACK: 1 because of the space character that's always prepended
+            assert len(event.data) > 1  # HACK: 1 because of the space character that's always prepended
 
 
-@pytest.mark.parametrize("file_path,endpoint", parameters)
-def test_streaming_transcription_json(
-    client: TestClient, file_path: str, endpoint: str
-):
+@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
+def test_streaming_transcription_json(client: TestClient, file_path: str, endpoint: str) -> None:
     extension = os.path.splitext(file_path)[1]
     with open(file_path, "rb") as f:
         data = f.read()
@@ -57,10 +49,8 @@ def test_streaming_transcription_json(
     TranscriptionJsonResponse(**json.loads(event.data))
 
 
-@pytest.mark.parametrize("file_path,endpoint", parameters)
-def test_streaming_transcription_verbose_json(
-    client: TestClient, file_path: str, endpoint: str
-):
+@pytest.mark.parametrize(("file_path", "endpoint"), parameters)
+def test_streaming_transcription_verbose_json(client: TestClient, file_path: str, endpoint: str) -> None:
     extension = os.path.splitext(file_path)[1]
     with open(file_path, "rb") as f:
         data = f.read()
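Note: the parametrize change swaps the comma-separated name string for a tuple of names (ruff's pytest-style rule PT006). Both forms are accepted by pytest; the tuple form avoids string parsing:

import pytest


@pytest.mark.parametrize(("file_path", "endpoint"), [("a.wav", "/v1/audio/transcriptions")])
def test_names_as_tuple(file_path: str, endpoint: str) -> None:
    # Equivalent to @pytest.mark.parametrize("file_path,endpoint", ...).
    assert file_path.endswith(".wav")
    assert endpoint.startswith("/v1")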