import asyncio
import zlib
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import Sequence, List, Tuple, Generator, Iterable, TYPE_CHECKING

import numpy as np
from hfendpoints.errors.config import UnsupportedModelArchitecture
from hfendpoints.openai import Context, run
from hfendpoints.openai.audio import (
    AutomaticSpeechRecognitionEndpoint,
    TranscriptionRequest,
    TranscriptionResponse,
    TranscriptionResponseKind,
    SegmentBuilder,
    Segment,
    Transcription,
    VerboseTranscription,
)
from librosa import load as load_audio, get_duration
from loguru import logger
from transformers import AutoConfig
from vllm import (
    AsyncEngineArgs,
    AsyncLLMEngine,
    SamplingParams,
)

from hfendpoints import EndpointConfig, Handler, ensure_supported_architectures

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer
    from vllm import CompletionOutput, RequestOutput
    from vllm.sequence import SampleLogprobs

SUPPORTED_MODEL_ARCHITECTURES = ["WhisperForConditionalGeneration"]


def chunk_audio_with_duration(
    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
) -> Sequence[np.ndarray]:
    """
    Chunk a mono audio timeseries so that each chunk lasts at most `maximum_duration_sec`.
    All chunks cover the same duration; the last one is zero-padded up to `maximum_duration_sec`.

    :param audio: The mono timeseries waveform of the audio
    :param maximum_duration_sec: The maximum length, in seconds, of each chunk
    :param sampling_rate: The number of samples representing one second of audio
    :return: List of numpy arrays, one per chunk
    """
    max_duration_samples = sampling_rate * maximum_duration_sec

    # Pad with silence so the total length is an exact multiple of the chunk size.
    # The modulo keeps the padding at 0 when the audio already fits exactly, avoiding the
    # creation of a trailing, fully silent chunk.
    padding = -len(audio) % max_duration_samples
    audio = np.pad(audio, (0, padding), constant_values=0.0)

    return np.split(audio, len(audio) // max_duration_samples)
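

# Illustrative example (not executed): with the handler's 22,050 Hz sampling rate and 30-second
# chunks (661,500 samples each), a 70-second waveform of 1,543,500 samples is padded with
# 441,000 zeros and split into 3 chunks of 661,500 samples each.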


def compression_ratio(text: str) -> float:
    """
    Compute the ratio between the size of the UTF-8 encoded text and its zlib-compressed size.
    Repetitive, degenerate transcriptions compress well and therefore exhibit a high ratio.

    :param text: The text to evaluate
    :return: Length of the encoded text divided by the length of its zlib-compressed form
    """
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))


def create_prompt(
    audio: np.ndarray,
    sampling_rate: int,
    language: int,
    timestamp_marker: int,
):
    """
    Generate the prompt, with the specified parameters, to submit for inference over Whisper

    :param audio: PCM data containing the audio signal representation
    :param sampling_rate: Number of samples in one second of audio
    :param language: Token id representing the language of the audio content
    :param timestamp_marker: Token id representing the temporal position within the audio content for this segment
    :return: Dictionary with all the prefilled values to call `generate`
    """
    return {
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {"audio": (audio, sampling_rate)},
        },
        "decoder_prompt": {
            "prompt_token_ids": [
                50258,  # <|startoftranscript|> in the multilingual Whisper vocabulary
                language,
                50360,  # <|transcribe|> (whisper-large-v3 numbering)
                timestamp_marker,
            ]
        },
    }
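

# For reference (assuming the multilingual whisper-large-v3 tokenizer), the decoder prompt built
# above decodes to Whisper's usual forced sequence, e.g. for French with timestamps enabled:
#   <|startoftranscript|> <|fr|> <|transcribe|> <|0.00|>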


def create_params(
    max_tokens: int, temperature: float, is_verbose: bool
) -> "SamplingParams":
    """
    Create the sampling parameters to submit for inference through vLLM `generate`

    :param max_tokens: Maximum number of tokens to generate
    :param temperature: Sampling temperature for the softmax
    :param is_verbose: Flag indicating whether the response is required to contain verbose output
    :return: `SamplingParams`
    """
    return SamplingParams.from_optional(
        max_tokens=max_tokens,
        skip_special_tokens=False,
        detokenize=False,
        temperature=temperature,
        logprobs=1 if is_verbose else None,
    )
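

# `logprobs=1` above asks vLLM to return per-step log-probabilities for the generated tokens;
# the verbose (verbose_json) path feeds those into `get_avg_logprob` below to report a
# per-segment avg_logprob, and it matches the `max_logprobs=1` configured on the engine.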


def get_avg_logprob(logprobs: "SampleLogprobs") -> float:
    """
    Aggregate the log probabilities over all generation steps by taking the log probability of the generated token

    :param logprobs: Iterable of log probabilities for all the generation steps
    :return: Averaged log probability as a floating-point number
    """
    sum_logp = sum(next(iter(step.values())).logprob for step in logprobs)
    return sum_logp / float(len(logprobs))


def process_chunk(
    tokenizer: "PreTrainedTokenizer",
    ids: np.ndarray,
    logprobs: "SampleLogprobs",
    request: TranscriptionRequest,
    segment_offset: int,
    timestamp_offset: int,
) -> Generator:
    """
    Decode a single transcribed audio chunk and generate all the associated segments

    :param tokenizer: The tokenizer used to map token ids back to text and timestamp tokens
    :param ids: Token ids generated by the model for this chunk
    :param logprobs: Per-step log probabilities returned by vLLM (may be None)
    :param request: The transcription request received from the user
    :param segment_offset: Index offset applied to the segments generated from this chunk
    :param timestamp_offset: Start time, in seconds, of this chunk within the whole audio
    :return: Generator yielding `(Segment, bool)` tuples, the boolean indicating whether the
        chunk ends with a single, unpaired timestamp token
    """
    # Token id of <|0.00|>; timestamp tokens occupy the top of the vocabulary, so comparing
    # against it masks out every timestamp position in the generated sequence.
    k_timestamp_token = tokenizer.convert_tokens_to_ids("<|0.00|>")

    timestamps_mask = ids >= k_timestamp_token

    if np.any(timestamps_mask):
        # True when the generated sequence ends with a single (unpaired) timestamp token.
        is_single_ending_timestamp = np.array_equal(timestamps_mask[-2:], [False, True])

        timestamp_start, timestamp_end = 0.0, 0.0
        slice_start = 0

        # The decoder prompt already ends with a start-of-segment timestamp, so the first
        # timestamp produced by the model (even-indexed `t`) closes a segment, while
        # odd-indexed timestamps open the next one.
        for t, position in enumerate(np.flatnonzero(timestamps_mask)):
            # Convert a token such as <|3.42|> back to the float 3.42.
            timestamp = float(tokenizer.convert_ids_to_tokens([ids[position]])[0][2:-2])

            if t % 2 == 0:
                timestamp_end = timestamp

                segment_ids = ids[slice_start:position]
                segment_text = tokenizer.decode(segment_ids)

                avg_logprob = get_avg_logprob(logprobs) if logprobs else float("nan")

                segment = (
                    SegmentBuilder()
                    .id(segment_offset + t)
                    .start(timestamp_offset + timestamp_start)
                    .end(timestamp_offset + timestamp_end)
                    .text(segment_text)
                    .tokens(segment_ids.tolist())
                    .temperature(request.temperature)
                    .avg_logprob(avg_logprob)
                    .compression_ratio(compression_ratio(segment_text))
                    .build()
                )

                yield segment, is_single_ending_timestamp

                slice_start = position
            else:
                timestamp_start = timestamp
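

# Each 30-second chunk is transcribed independently; `process_chunks` below shifts every segment
# by its chunk's start time (idx * WHISPER_SEGMENT_DURATION_SEC) and offsets the segment ids by
# the number of segments already materialized, keeping timestamps and ids globally consistent.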


def process_chunks(
    tokenizer: "PreTrainedTokenizer",
    chunks: List["RequestOutput"],
    request: TranscriptionRequest,
) -> Tuple[List[Segment], str]:
    """
    Iterate over all the audio chunks' outputs and consolidate them into segments, whether the response is verbose or not

    :param tokenizer: The tokenizer to use for decoding tokens
    :param chunks: Transcribed audio chunks
    :param request: Received request from the user
    :return: `Tuple[List[Segment], str]` holding all the consolidated segments along with the full transcribed text
    """
    materialized_segments, materialized_segments_tokens_acc = [], []

    for idx, chunk in enumerate(chunks):
        time_offset = idx * WhisperHandler.WHISPER_SEGMENT_DURATION_SEC
        segment_offset = len(materialized_segments)

        generation: "CompletionOutput" = chunk.outputs[-1]
        ids: np.ndarray = np.asarray(generation.token_ids)
        logprobs = generation.logprobs

        for segment, _is_continuation in process_chunk(
            tokenizer, ids, logprobs, request, segment_offset, time_offset
        ):
            materialized_segments.append(segment)

        materialized_segments_tokens_acc += generation.token_ids

    text = tokenizer.decode(
        materialized_segments_tokens_acc,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    return materialized_segments, text


class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
    WHISPER_SEGMENT_DURATION_SEC = 30
    WHISPER_SAMPLING_RATE = 22050

    __slots__ = ("_engine",)

    def __init__(self, model_id_or_path: str):
        super().__init__(model_id_or_path)

        self._engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model_id_or_path,
                task="transcription",
                device="auto",
                dtype="bfloat16",
                kv_cache_dtype="fp8",
                enforce_eager=False,
                enable_prefix_caching=True,
                max_logprobs=1,
                disable_log_requests=True,
            )
        )
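
        # The engine arguments above are performance-oriented choices for this handler: the fp8
        # KV cache trades a little precision for memory, enforce_eager=False keeps CUDA graphs
        # enabled, and max_logprobs=1 is the minimum needed for the logprobs=1 requested when
        # building verbose responses (see `create_params`).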

    async def transcribe(
        self,
        ctx: Context,
        request: TranscriptionRequest,
        tokenizer: "PreTrainedTokenizer",
        audio_chunks: Iterable[np.ndarray],
        params: "SamplingParams",
    ) -> Tuple[List[Segment], str]:
        async def __agenerate__(request_id: str, prompt, params):
            """
            Helper coroutine to unroll the asynchronous generator and return its last element

            :param request_id: Unique identifier for this request
            :param prompt: The prompt to submit for inference on vLLM through `generate(...)`
            :param params: The parameters passed along with the prompt for inference on vLLM through `generate(...)`
            :return: The final `RequestOutput` produced for this request
            """
            async for step in self._engine.generate(prompt, params, request_id):
                pass
            return step

        convert_tokens_to_ids = lru_cache(tokenizer.convert_tokens_to_ids)

        # These values do not depend on the chunk being processed, so resolve them once.
        is_verbose = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
        language = convert_tokens_to_ids(f"<|{request.language}|>")
        timestamp = convert_tokens_to_ids(
            "<|0.00|>" if is_verbose else "<|notimestamps|>"
        )

        coro_handles = []
        for audio_chunk_id, audio_chunk in enumerate(audio_chunks):
            request_id = f"{ctx.request_id}-{audio_chunk_id}"

            prompt = create_prompt(
                audio_chunk, WhisperHandler.WHISPER_SAMPLING_RATE, language, timestamp
            )

            coro_handles += [
                asyncio.create_task(__agenerate__(request_id, prompt, params))
            ]

        text_chunks = await asyncio.gather(*coro_handles)

        segments, text = await asyncio.get_event_loop().run_in_executor(
            None, process_chunks, tokenizer, text_chunks, request
        )
        return segments, text
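
    # Design note: every chunk is submitted to vLLM concurrently and gathered at the end so the
    # engine can batch them, while the CPU-bound token decoding in `process_chunks` is pushed to
    # the default thread-pool executor to keep the event loop free for other requests.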

    async def __call__(
        self, request: TranscriptionRequest, ctx: Context
    ) -> TranscriptionResponse:
        with logger.contextualize(request_id=ctx.request_id):
            with memoryview(request) as audio:
                is_verbose = (
                    request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
                )

                tokenizer = asyncio.create_task(self._engine.get_tokenizer())
                model_config = asyncio.create_task(self._engine.get_model_config())

                # Decode the incoming payload to a mono waveform at the handler's sampling rate
                (waveform, sampling) = load_audio(
                    BytesIO(audio), sr=WhisperHandler.WHISPER_SAMPLING_RATE, mono=True
                )
                logger.debug(
                    f"Successfully decoded audio chunk of {len(waveform)} PCM samples"
                )

                # Leave room for the 4 decoder prompt tokens built in `create_prompt`
                max_tokens = (await model_config).max_model_len - 4
                params = create_params(max_tokens, request.temperature, is_verbose)

                audio_chunks = chunk_audio_with_duration(
                    waveform,
                    maximum_duration_sec=WhisperHandler.WHISPER_SEGMENT_DURATION_SEC,
                    sampling_rate=WhisperHandler.WHISPER_SAMPLING_RATE,
                )

                segments, text = await self.transcribe(
                    ctx, request, await tokenizer, audio_chunks, params
                )

                match request.response_kind:
                    case TranscriptionResponseKind.VERBOSE_JSON:
                        return TranscriptionResponse.verbose(
                            VerboseTranscription(
                                text=text,
                                duration=get_duration(y=waveform, sr=sampling),
                                language=request.language,
                                segments=segments,
                            )
                        )
                    case TranscriptionResponseKind.JSON:
                        return TranscriptionResponse.json(text)

                    case TranscriptionResponseKind.TEXT:
                        return TranscriptionResponse.text(text)

                raise ValueError(f"Invalid response_kind ({request.response_kind})")


def entrypoint():
    endpoint_config = EndpointConfig.from_env()

    if (model_local_path := Path(endpoint_config.model_id)).exists():
        if (config_local_path := (model_local_path / "config.json")).exists():
            config = AutoConfig.from_pretrained(config_local_path)
            ensure_supported_architectures(config, SUPPORTED_MODEL_ARCHITECTURES)

    endpoint = AutomaticSpeechRecognitionEndpoint(
        WhisperHandler(endpoint_config.model_id)
    )

    run(endpoint, endpoint_config.interface, endpoint_config.port)


if __name__ == "__main__":
    entrypoint()