import asyncio
import zlib
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import Sequence, List, Tuple, Generator, Iterable, TYPE_CHECKING

import numpy as np
from hfendpoints.errors.config import UnsupportedModelArchitecture
from hfendpoints.openai import Context, run
from hfendpoints.openai.audio import (
    AutomaticSpeechRecognitionEndpoint,
    TranscriptionRequest,
    TranscriptionResponse,
    TranscriptionResponseKind,
    SegmentBuilder,
    Segment,
    Transcription,
    VerboseTranscription,
)
from librosa import load as load_audio, get_duration
from loguru import logger
from transformers import AutoConfig
from vllm import (
    AsyncEngineArgs,
    AsyncLLMEngine,
    SamplingParams,
)

from hfendpoints import EndpointConfig, Handler, ensure_supported_architectures

if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer
    from vllm import CompletionOutput, RequestOutput
    from vllm.sequence import SampleLogprobs

SUPPORTED_MODEL_ARCHITECTURES = ["WhisperForConditionalGeneration"]


def chunk_audio_with_duration(
    audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
) -> Sequence[np.ndarray]:
    """
    Chunk a mono audio timeseries so that each chunk is `maximum_duration_sec` long.

    The input is zero-padded so that every chunk, including the last one, has the same length.

    :param audio: The mono timeseries waveform of the audio
    :param maximum_duration_sec: The maximum length, in seconds, of each chunk
    :param sampling_rate: The number of samples representing one second of audio
    :return: List of numpy arrays, one per chunk
    """
    # Pad the input so that every chunk is exactly `maximum_duration_sec` long
    max_duration_samples = sampling_rate * maximum_duration_sec
    padding = max_duration_samples - np.remainder(len(audio), max_duration_samples)
    audio = np.pad(audio, (0, padding), constant_values=0.0)
    return np.split(audio, len(audio) // max_duration_samples)


def compression_ratio(text: str) -> float:
    """
    Compute the ratio between the raw UTF-8 length of `text` and its zlib-compressed length.

    Highly repetitive transcriptions compress well and therefore yield a high ratio,
    which is used as a quality heuristic on Whisper outputs.

    :param text: The transcribed text to evaluate
    :return: Compression ratio as a floating-point number
    """
    text_bytes = text.encode("utf-8")
    return len(text_bytes) / len(zlib.compress(text_bytes))


def create_prompt(
    audio: np.ndarray,
    sampling_rate: int,
    language: int,
    timestamp_marker: int,
):
    """
    Generate the prompt, with the specified parameters, to submit for inference over Whisper

    :param audio: PCM data containing the audio signal representation
    :param sampling_rate: Number of samples in one second of audio
    :param language: Token id representing the language of the audio content
    :param timestamp_marker: Token id representing the temporal position within the audio content for this segment
    :return: Dictionary with all the prefilled values to call `generate`
    """
    return {
        "encoder_prompt": {
            "prompt": "",
            "multi_modal_data": {"audio": (audio, sampling_rate)},
        },
        "decoder_prompt": {
            "prompt_token_ids": [
                50258,  # <|startoftranscript|>
                language,
                50360,  # task token (<|transcribe|> in the large-v3 vocabulary)
                timestamp_marker,
            ]
        },
    }


def create_params(
    max_tokens: int, temperature: float, is_verbose: bool
) -> "SamplingParams":
    """
    Create sampling parameters to submit for inference through vLLM `generate`

    :param max_tokens: Maximum number of tokens to generate
    :param temperature: Sampling temperature for the softmax
    :param is_verbose: Flag indicating whether the response is required to contain verbose output
    :return: `SamplingParams`
    """
    return SamplingParams.from_optional(
        # output_kind=RequestOutputKind.FINAL_ONLY,  # Change if streaming
        max_tokens=max_tokens,
        skip_special_tokens=False,
        detokenize=False,
        temperature=temperature,
        logprobs=1 if is_verbose else None,
    )
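# Worked example for `chunk_audio_with_duration` (illustrative only; 16 kHz and the 65 s
# duration are arbitrary here, the handler itself uses WhisperHandler.WHISPER_SAMPLING_RATE
# and 30 s chunks):
#
#   >>> pcm = np.zeros(65 * 16_000, dtype=np.float32)
#   >>> pieces = chunk_audio_with_duration(pcm, maximum_duration_sec=30, sampling_rate=16_000)
#   >>> [len(p) // 16_000 for p in pieces]
#   [30, 30, 30]
#
# 65 s of audio is zero-padded up to 90 s so that it splits into three equal 30 s chunks.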
def get_avg_logprob(logprobs: "SampleLogprobs") -> float:
    """
    Aggregate the log probabilities over all generation steps by averaging the log probability
    of the sampled token at each step

    :param logprobs: Iterable of log probabilities for all the generation steps
    :return: Averaged log probability as a floating-point number
    """
    sum_logp = sum(next(iter(_step_.values())).logprob for _step_ in logprobs)
    return sum_logp / float(len(logprobs))


def process_chunk(
    tokenizer: "PreTrainedTokenizer",
    ids: np.ndarray,
    logprobs: "SampleLogprobs",
    request: TranscriptionRequest,
    segment_offset: int,
    timestamp_offset: int,
) -> Generator:
    """
    Decode a single transcribed audio chunk and generate all the segments associated with it

    :param tokenizer: The tokenizer used to decode generated token ids
    :param ids: Generated token ids for this chunk
    :param logprobs: Log probabilities for each generation step
    :param request: Received request from the user
    :param segment_offset: Index offset of the first segment produced by this chunk
    :param timestamp_offset: Start time, in seconds, of this chunk within the full audio
    :return: Generator yielding `(Segment, bool)` tuples
    """
    # Some constants
    k_timestamp_token = lru_cache(tokenizer.convert_tokens_to_ids)("<|0.00|>")

    # Detect start of transcript token
    # sot_mask = ids == k_sot_token

    # Timestamps are expected to have ids greater than token_id(<|0.00|>)
    # We create a mask for all the potential tokens which are >= token_id(<|0.00|>)
    timestamps_mask = ids >= k_timestamp_token

    if np.any(timestamps_mask):
        # If we have a timestamp token, we need to check whether it's a final token or a final then the next
        is_single_ending_timestamp = np.array_equal(timestamps_mask[-2:], [False, True])

        # Iterate over timestamps
        timestamp_start, timestamp_end = 0.0, 0.0
        slice_start = 0

        for t, position in enumerate(np.flatnonzero(timestamps_mask)):
            timestamp = float(tokenizer.convert_ids_to_tokens([ids[position]])[0][2:-2])
            if t % 2 == 0:
                timestamp_end = timestamp

                # Retrieve segment info
                segment_ids = ids[slice_start:position]
                segment_text = tokenizer.decode(segment_ids)

                # Compute the avg_logprob
                avg_logprob = get_avg_logprob(logprobs) if logprobs else float("nan")

                # no-speech logprob
                # no_speech_token_id = lru_cache(tokenizer.convert_tokens_to_ids("<|nospeech|>"))
                # no_speech_logprob = logprobs[no_speech_token_id]

                # Materialize the segment in memory
                segment = (
                    SegmentBuilder()
                    .id(segment_offset + t)
                    .start(timestamp_offset + timestamp_start)
                    .end(timestamp_offset + timestamp_end)
                    .text(segment_text)
                    .tokens(segment_ids.tolist())
                    .temperature(request.temperature)
                    .avg_logprob(avg_logprob)
                    .compression_ratio(compression_ratio(segment_text))
                    .build()
                )

                yield segment, is_single_ending_timestamp

                # Update the start position
                slice_start = position
            else:
                timestamp_start = timestamp
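# How `process_chunk` pairs timestamps (shown on a made-up token, not real model output):
# the opening <|0.00|> marker is part of the decoder prompt, so within the *generated* ids
# every even-indexed timestamp token closes a segment and every odd-indexed one opens the
# next. A timestamp token such as "<|3.80|>" is parsed by stripping its delimiters:
#
#   >>> token = "<|3.80|>"
#   >>> float(token[2:-2])
#   3.8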
def process_chunks(
    tokenizer: "PreTrainedTokenizer",
    chunks: List["RequestOutput"],
    request: TranscriptionRequest,
) -> Tuple[List[Segment], str]:
    """
    Iterate over all the audio chunks' outputs and consolidate them into segment(s),
    whether the response is verbose or not

    :param tokenizer: The tokenizer to use for decoding tokens
    :param chunks: Transcribed audio chunks
    :param request: Received request from the user
    :return: `Tuple[List[Segment], str]` holding all the consolidated segments along with the full transcribed text
    """
    # k_nospeech_token = tokenizer.convert_tokens_to_ids("<|nospeech|>")
    # k_sot_token = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")

    materialized_segments, materialized_segments_tokens_acc = [], []

    # Iterate over chunks
    for idx, chunk in enumerate(chunks):
        time_offset = idx * WhisperHandler.WHISPER_SEGMENT_DURATION_SEC
        segment_offset = len(materialized_segments)

        generation: "CompletionOutput" = chunk.outputs[-1]
        ids: np.ndarray = np.asarray(generation.token_ids)
        logprobs = generation.logprobs

        for segment, _is_continuation in process_chunk(
            tokenizer, ids, logprobs, request, segment_offset, time_offset
        ):
            materialized_segments.append(segment)

        # Accumulate the tokens for full decoding
        materialized_segments_tokens_acc += generation.token_ids

    text = tokenizer.decode(
        materialized_segments_tokens_acc,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    return materialized_segments, text


class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
    WHISPER_SEGMENT_DURATION_SEC = 30
    WHISPER_SAMPLING_RATE = 22050

    __slots__ = ("_engine",)

    def __init__(self, model_id_or_path: str):
        super().__init__(model_id_or_path)
        self._engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(
                model_id_or_path,
                task="transcription",
                device="auto",
                dtype="bfloat16",
                kv_cache_dtype="fp8",
                enforce_eager=False,
                enable_prefix_caching=True,
                max_logprobs=1,  # TODO(mfuntowicz): Set from config?
                disable_log_requests=True,
            )
        )

    async def transcribe(
        self,
        ctx: Context,
        request: TranscriptionRequest,
        tokenizer: "PreTrainedTokenizer",
        audio_chunks: Iterable[np.ndarray],
        params: "SamplingParams",
    ) -> Tuple[List[Segment], str]:
        async def __agenerate__(request_id: str, prompt, params):
            """
            Helper coroutine unrolling the asynchronous generator and returning its last element

            :param request_id: Unique identifier for this request
            :param prompt: The prompt to submit for inference on vLLM through `generate(...)`
            :param params: The parameters passed along with the prompt for inference on vLLM through `generate(...)`
            :return: The final `RequestOutput` produced for this request
            """
            # Submit for inference on the segment & keep track of the background task
            step = None
            async for step in self._engine.generate(prompt, params, request_id):
                pass
            return step

        # Wrap tokenizer lookups with an LRU cache to avoid repeated vocabulary lookups
        convert_tokens_to_ids = lru_cache(tokenizer.convert_tokens_to_ids)

        is_verbose = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON

        coro_handles = []
        for audio_chunk_id, audio_chunk in enumerate(audio_chunks):
            # Generate a suffixed request-id to submit and identify through the vLLM scheduler
            request_id = f"{ctx.request_id}-{audio_chunk_id}"

            # Compute the initial prompt for the segment
            language = convert_tokens_to_ids(f"<|{request.language}|>")
            timestamp_marker = convert_tokens_to_ids(
                "<|0.00|>" if is_verbose else "<|notimestamps|>"
            )
            prompt = create_prompt(
                audio_chunk,
                WhisperHandler.WHISPER_SAMPLING_RATE,
                language,
                timestamp_marker,
            )

            # Submit the task
            coro_handles += [
                asyncio.create_task(__agenerate__(request_id, prompt, params))
            ]

        # Wait for all the segments to complete
        text_chunks = await asyncio.gather(*coro_handles)

        # if not is_cancelled.cancel_called:
        segments, text = await asyncio.get_event_loop().run_in_executor(
            None, process_chunks, tokenizer, text_chunks, request
        )

        return segments, text

    async def __call__(
        self, request: TranscriptionRequest, ctx: Context
    ) -> TranscriptionResponse:
        with logger.contextualize(request_id=ctx.request_id):
            with memoryview(request) as audio:
                # Check if we need to enable the verbose path
                is_verbose = (
                    request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
                )

                # Retrieve the tokenizer and model config asynchronously while we decode audio
                tokenizer = asyncio.create_task(self._engine.get_tokenizer())
                model_config = asyncio.create_task(self._engine.get_model_config())

                # Decode audio with librosa (for now)
                # TODO: Use native (Rust provided) decoding
                (waveform, sampling) = load_audio(
                    BytesIO(audio),
                    sr=WhisperHandler.WHISPER_SAMPLING_RATE,
                    mono=True,
                )
                logger.debug(f"Successfully decoded {len(waveform)} PCM samples")

                # Create parameters
                max_tokens = (await model_config).max_model_len - 4
                params = create_params(max_tokens, request.temperature, is_verbose)

                # Chunk audio into pieces
                audio_chunks = chunk_audio_with_duration(
                    waveform,
                    maximum_duration_sec=WhisperHandler.WHISPER_SEGMENT_DURATION_SEC,
                    sampling_rate=WhisperHandler.WHISPER_SAMPLING_RATE,
                )

                # Submit audio pieces to the batcher and gather them all
                segments, text = await self.transcribe(
                    ctx, request, await tokenizer, audio_chunks, params
                )

                match request.response_kind:
                    case TranscriptionResponseKind.VERBOSE_JSON:
                        return TranscriptionResponse.verbose(
                            VerboseTranscription(
                                text=text,
                                duration=get_duration(y=waveform, sr=sampling),
                                language=request.language,
                                segments=segments,
                                # word=None
                            )
                        )
                    case TranscriptionResponseKind.JSON:
                        return TranscriptionResponse.json(text)
                    case TranscriptionResponseKind.TEXT:
                        return TranscriptionResponse.text(text)

                # I don't foresee any case where this would happen, but at least we are safe
                raise ValueError(f"Invalid response_kind ({request.response_kind})")
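# Hypothetical offline sketch of the per-chunk prompt built inside `transcribe` above.
# The checkpoint name is an assumption and the zero waveform stands in for real audio;
# token ids differ between Whisper revisions.
#
#   from transformers import AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained("openai/whisper-large-v3")
#   language = tok.convert_tokens_to_ids("<|en|>")
#   marker = tok.convert_tokens_to_ids("<|0.00|>")  # "<|notimestamps|>" for non-verbose requests
#   prompt = create_prompt(
#       np.zeros(30 * WhisperHandler.WHISPER_SAMPLING_RATE, dtype=np.float32),
#       WhisperHandler.WHISPER_SAMPLING_RATE,
#       language,
#       marker,
#   )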
def entrypoint():
    # Retrieve the endpoint configuration
    endpoint_config = EndpointConfig.from_env()

    # Ensure the model architecture is supported if the model is pre-downloaded
    if (model_local_path := Path(endpoint_config.model_id)).exists():
        if (config_local_path := (model_local_path / "config.json")).exists():
            config = AutoConfig.from_pretrained(config_local_path)
            ensure_supported_architectures(config, SUPPORTED_MODEL_ARCHITECTURES)

    # Initialize the endpoint
    endpoint = AutomaticSpeechRecognitionEndpoint(
        WhisperHandler(endpoint_config.model_id)
    )

    # Serve the model
    run(endpoint, endpoint_config.interface, endpoint_config.port)


if __name__ == "__main__":
    entrypoint()
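# Example request against a running endpoint (hypothetical: the exact route and port come
# from `hfendpoints` / EndpointConfig; the path below mirrors the OpenAI audio API that
# TranscriptionRequest / TranscriptionResponse model):
#
#   curl http://localhost:8000/v1/audio/transcriptions \
#        -F file=@sample.wav \
#        -F language=en \
#        -F response_format=verbose_json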