# whisper-vllm-gpu / handler.py
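"""
OpenAI-compatible automatic speech recognition endpoint serving Whisper through vLLM.

Incoming audio is decoded with librosa, split into 30-second chunks, transcribed
concurrently through the vLLM async engine, and consolidated back into OpenAI-style
`text`, `json` or `verbose_json` responses.
"""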
import asyncio
import zlib
from functools import lru_cache
from io import BytesIO
from pathlib import Path
from typing import Sequence, List, Tuple, Generator, Iterable, TYPE_CHECKING
import numpy as np
from hfendpoints.errors.config import UnsupportedModelArchitecture
from hfendpoints.openai import Context, run
from hfendpoints.openai.audio import (
AutomaticSpeechRecognitionEndpoint,
TranscriptionRequest,
TranscriptionResponse,
TranscriptionResponseKind,
SegmentBuilder,
Segment,
Transcription,
VerboseTranscription,
)
from librosa import load as load_audio, get_duration
from loguru import logger
from transformers import AutoConfig
from vllm import (
AsyncEngineArgs,
AsyncLLMEngine,
SamplingParams,
)
from hfendpoints import EndpointConfig, Handler, ensure_supported_architectures
if TYPE_CHECKING:
from transformers import PreTrainedTokenizer
from vllm import CompletionOutput, RequestOutput
from vllm.sequence import SampleLogprobs
SUPPORTED_MODEL_ARCHITECTURES = ["WhisperForConditionalGeneration"]
def chunk_audio_with_duration(
audio: np.ndarray, maximum_duration_sec: int, sampling_rate: int
) -> Sequence[np.ndarray]:
"""
    Chunk a mono audio time series into pieces of exactly `maximum_duration_sec` seconds.
    The waveform is zero-padded at the end so that the last chunk has the same length as the others.
:param audio: The mono timeseries waveform of the audio
:param maximum_duration_sec: The maximum length, in seconds, for each chunk
:param sampling_rate: The number of samples to represent one second of audio
    :return: List of numpy arrays, one per chunk
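    Example (a minimal sketch, assuming a 16 kHz mono waveform of 2.5 seconds):
        >>> sr = 16_000
        >>> chunks = chunk_audio_with_duration(np.zeros(int(sr * 2.5)), 1, sr)
        >>> [len(c) for c in chunks]
        [16000, 16000, 16000]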
"""
    # We pad the input with zeros so that every chunk is exactly `max_duration_samples` long
    max_duration_samples = sampling_rate * maximum_duration_sec
    padding = -len(audio) % max_duration_samples  # no padding when the length is already a multiple
    audio = np.pad(audio, (0, padding), constant_values=0.0)
return np.split(audio, len(audio) // max_duration_samples)
def compression_ratio(text: str) -> float:
"""
    Compute the ratio between the raw UTF-8 byte length of `text` and its zlib-compressed length.
    Highly repetitive output yields a high ratio and is a common signal of degenerate decoding.
    :param text: The text to evaluate
    :return: Compression ratio as a floating-point number (higher means more repetitive)
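    Example (illustrative): repetitive text compresses much better than natural prose:
        >>> compression_ratio("abc " * 100) > compression_ratio("The quick brown fox jumps over the lazy dog")
        True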
"""
text_bytes = text.encode("utf-8")
return len(text_bytes) / len(zlib.compress(text_bytes))
def create_prompt(
audio: np.ndarray,
sampling_rate: int,
language: int,
timestamp_marker: int,
) -> dict:
"""
Generate the right prompt with the specific parameters to submit for inference over Whisper
:param audio: PCM data containing audio signal representation
:param sampling_rate: Number of samples in one second of audio
:param language: Token id representing the language of the audio content
:param timestamp_marker: Token id representing the temporal position within the audio content for this segment
:return: Dictionary with all the prefilled value to call `generate`
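    For example, with an English request and verbose output, the decoded decoder prompt is
    roughly "<|startoftranscript|><|en|><|transcribe|><|0.00|>" (the hardcoded token ids 50258
    and 50360 assume the multilingual large-v3 vocabulary).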
"""
return {
"encoder_prompt": {
"prompt": "",
"multi_modal_data": {"audio": (audio, sampling_rate)},
},
"decoder_prompt": {
"prompt_token_ids": [
                50258,  # <|startoftranscript|>
language,
                50360,  # <|transcribe|> (assuming the multilingual large-v3 vocabulary)
timestamp_marker,
]
},
}
def create_params(
max_tokens: int, temperature: float, is_verbose: bool
) -> "SamplingParams":
"""
Create sampling parameters to submit for inference through vLLM `generate`
:param max_tokens: Maximum number of tokens to generate
:param temperature: Sampling temperature for the softmax
    :param is_verbose: Flag indicating whether the response is required to contain verbose output
:return: `SamplingParams`
"""
return SamplingParams.from_optional(
# output_kind=RequestOutputKind.FINAL_ONLY, # Change if streaming
max_tokens=max_tokens,
skip_special_tokens=False,
detokenize=False,
temperature=temperature,
logprobs=1 if is_verbose else None,
)
def get_avg_logprob(logprobs: "SampleLogprobs") -> float:
"""
    Average the log probabilities over all generation steps, using at each step the log probability of the token that was actually generated
:param logprobs: Iterable of log probabilities for all the generation steps
:return: Averaged log probability as floating-point number
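    This mirrors the `avg_logprob` field of OpenAI's verbose transcription output:
    avg_logprob = (1 / N) * sum_t log p(y_t | y_<t, audio), with N the number of generation steps.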
"""
sum_logp = sum(next(iter(_step_.values())).logprob for _step_ in logprobs)
return sum_logp / float(len(logprobs))
def process_chunk(
tokenizer: "PreTrainedTokenizer",
ids: np.ndarray,
logprobs: "SampleLogprobs",
request: TranscriptionRequest,
segment_offset: int,
timestamp_offset: int,
) -> Generator:
"""
    Decode a single transcribed audio chunk and generate all the associated segments
    :param tokenizer: The tokenizer to use for decoding tokens
    :param ids: Token ids generated for this chunk
    :param logprobs: Log probabilities for each generation step, or None when not requested
    :param request: Received request from the user
    :param segment_offset: Number of segments already emitted by the previous chunks
    :param timestamp_offset: Start time of this chunk, in seconds, within the full audio
    :return: Generator yielding `(Segment, is_single_ending_timestamp)` tuples
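    Note: the decoder prompt already carries the opening `<|0.00|>`, so within `ids` the
    even-indexed timestamp tokens close a segment and the odd-indexed ones open the next,
    e.g. (timestamp values made up): ... text ... <|5.20|> <|5.20|> ... text ... <|29.98|>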
"""
# Some constants
    k_timestamp_token = tokenizer.convert_tokens_to_ids("<|0.00|>")
# Detect start of transcript token
# sot_mask = ids == k_sot_token
# Timestamps are expected to have ids greater than token_id(<|0.00|>)
# We create a mask for all the potential tokens which are >= token_id(<|0.00|>)
timestamps_mask = ids >= k_timestamp_token
if np.any(timestamps_mask):
        # Check whether the generation ends with a single timestamp token rather than a timestamp pair
is_single_ending_timestamp = np.array_equal(timestamps_mask[-2:], [False, True])
# Iterate over timestamps
timestamp_start, timestamp_end = 0.0, 0.0
slice_start = 0
for t, position in enumerate(np.flatnonzero(timestamps_mask)):
timestamp = float(tokenizer.convert_ids_to_tokens([ids[position]])[0][2:-2])
if t % 2 == 0:
timestamp_end = timestamp
# Retrieve segment info
segment_ids = ids[slice_start:position]
segment_text = tokenizer.decode(segment_ids)
# Compute the avg_logprob
avg_logprob = get_avg_logprob(logprobs) if logprobs else float("nan")
# no-speech logprob
# no_speech_token_id = lru_cache(tokenizer.convert_tokens_to_ids("<|nospeech|>"))
# no_speech_logprob = logprobs[no_speech_token_id]
# Materialize the segment in memory
segment = (
SegmentBuilder()
                    .id(segment_offset + t // 2)  # one segment per pair of timestamp tokens
.start(timestamp_offset + timestamp_start)
.end(timestamp_offset + timestamp_end)
.text(segment_text)
.tokens(segment_ids.tolist())
.temperature(request.temperature)
.avg_logprob(avg_logprob)
.compression_ratio(compression_ratio(segment_text))
.build()
)
yield segment, is_single_ending_timestamp
# Update the start position
slice_start = position
else:
timestamp_start = timestamp
def process_chunks(
tokenizer: "PreTrainedTokenizer",
chunks: List["RequestOutput"],
request: TranscriptionRequest,
) -> Tuple[List[Segment], str]:
"""
    Iterate over all the audio chunks' outputs and consolidate them into segments along with the full transcribed text
:param tokenizer: The tokenizer to use for decoding tokens
:param chunks: Transcribed audio chunks
:param request: Received request from the user
:return: `Tuple[List[Segment], str]` holding all the consolidated segments along with full transcribed text
"""
# k_nospeech_token = tokenizer.convert_tokens_to_ids("<|nospeech|>")
# k_sot_token = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
materialized_segments, materialized_segments_tokens_acc = [], []
# Iterate over segments
for idx, chunk in enumerate(chunks):
time_offset = idx * WhisperHandler.WHISPER_SEGMENT_DURATION_SEC
segment_offset = len(materialized_segments)
generation: "CompletionOutput" = chunk.outputs[-1]
ids: np.ndarray = np.asarray(generation.token_ids)
logprobs = generation.logprobs
for segment, _is_continuation in process_chunk(
tokenizer, ids, logprobs, request, segment_offset, time_offset
):
materialized_segments.append(segment)
# Accumulate the tokens for full decoding
materialized_segments_tokens_acc += generation.token_ids
text = tokenizer.decode(
materialized_segments_tokens_acc,
skip_special_tokens=True,
clean_up_tokenization_spaces=True,
)
return materialized_segments, text
class WhisperHandler(Handler[TranscriptionRequest, TranscriptionResponse]):
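    # Whisper encoders consume fixed 30-second windows, so incoming audio is chunked accordingly.
    # WHISPER_SAMPLING_RATE is the rate the incoming audio is decoded at (librosa) and is
    # forwarded to vLLM alongside each waveform.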
WHISPER_SEGMENT_DURATION_SEC = 30
WHISPER_SAMPLING_RATE = 22050
__slots__ = ("_engine",)
def __init__(self, model_id_or_path: str):
super().__init__(model_id_or_path)
self._engine = AsyncLLMEngine.from_engine_args(
AsyncEngineArgs(
model_id_or_path,
task="transcription",
device="auto",
dtype="bfloat16",
kv_cache_dtype="fp8",
enforce_eager=False,
enable_prefix_caching=True,
max_logprobs=1, # TODO(mfuntowicz) : Set from config?
disable_log_requests=True,
)
)
async def transcribe(
self,
ctx: Context,
request: TranscriptionRequest,
tokenizer: "PreTrainedTokenizer",
audio_chunks: Iterable[np.ndarray],
params: "SamplingParams",
    ) -> Tuple[List[Segment], str]:
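        """
        Submit every audio chunk to the vLLM engine concurrently and consolidate the outputs.
        :param ctx: Request context carrying the original request id
        :param request: Received request from the user
        :param tokenizer: The tokenizer to use for decoding tokens
        :param audio_chunks: Mono PCM chunks, each at most `WHISPER_SEGMENT_DURATION_SEC` long
        :param params: Sampling parameters forwarded to vLLM `generate(...)`
        :return: Consolidated segments along with the full transcribed text
        """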
async def __agenerate__(request_id: str, prompt, params):
"""
Helper method to unroll asynchronous generator and return the last element
:param request_id: Unique identifier for this request
:param prompt: The prompt to submit for inference on vLLM through `generate(...)`
:param params: The parameters passed along with the prompt for inference on vLLM through `generate(...)`
:return: `CompletionOutput`
"""
# Submit for inference on the segment & keep track of the background task
async for step in self._engine.generate(prompt, params, request_id):
pass
return step
        # Wrap tokenizer lookups with an LRU cache to avoid repeated vocabulary scans
        convert_tokens_to_ids = lru_cache(tokenizer.convert_tokens_to_ids)

        # The decoder prompt is identical for every chunk: language token followed by either
        # the first timestamp marker (verbose output) or the no-timestamps marker
        is_verbose = request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
        language = convert_tokens_to_ids(f"<|{request.language}|>")
        timestamp_marker = convert_tokens_to_ids(
            "<|0.00|>" if is_verbose else "<|notimestamps|>"
        )

        coro_handles = []
        for audio_chunk_id, audio_chunk in enumerate(audio_chunks):
            # Generate a suffixed request-id to identify each chunk through the vLLM scheduler
            request_id = f"{ctx.request_id}-{audio_chunk_id}"

            # Compute the prompt for this chunk (its start time is reconstructed later in `process_chunks`)
            prompt = create_prompt(
                audio_chunk, WhisperHandler.WHISPER_SAMPLING_RATE, language, timestamp_marker
            )
# Submit the task
coro_handles += [
asyncio.create_task(__agenerate__(request_id, prompt, params))
]
        # Wait for all the chunks to complete
text_chunks = await asyncio.gather(*coro_handles)
# if not is_cancelled.cancel_called:
segments, text = await asyncio.get_event_loop().run_in_executor(
None, process_chunks, tokenizer, text_chunks, request
)
return segments, text
async def __call__(
self, request: TranscriptionRequest, ctx: Context
) -> TranscriptionResponse:
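        """
        Handle a transcription request end-to-end: decode the incoming audio, chunk it,
        transcribe every chunk through vLLM and format the response according to
        `request.response_kind`.
        """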
with logger.contextualize(request_id=ctx.request_id):
with memoryview(request) as audio:
# Check if we need to enable the verbose path
is_verbose = (
request.response_kind == TranscriptionResponseKind.VERBOSE_JSON
)
# Retrieve the tokenizer and model config asynchronously while we decode audio
tokenizer = asyncio.create_task(self._engine.get_tokenizer())
model_config = asyncio.create_task(self._engine.get_model_config())
# Decode audio from librosa (for now)
# TODO: Use native (Rust provided) decoding
                (waveform, sampling) = load_audio(
                    BytesIO(audio), sr=WhisperHandler.WHISPER_SAMPLING_RATE, mono=True
                )
                logger.debug(
                    f"Successfully decoded {len(waveform)} PCM samples of audio"
                )
                # Create sampling parameters, capping generation below the model context length
                max_tokens = (await model_config).max_model_len - 4
params = create_params(max_tokens, request.temperature, is_verbose)
# Chunk audio in pieces
audio_chunks = chunk_audio_with_duration(
waveform,
maximum_duration_sec=WhisperHandler.WHISPER_SEGMENT_DURATION_SEC,
sampling_rate=WhisperHandler.WHISPER_SAMPLING_RATE,
)
# Submit audio pieces to the batcher and gather them all
segments, text = await self.transcribe(
ctx, request, await tokenizer, audio_chunks, params
)
match request.response_kind:
case TranscriptionResponseKind.VERBOSE_JSON:
return TranscriptionResponse.verbose(
VerboseTranscription(
text=text,
duration=get_duration(y=waveform, sr=sampling),
language=request.language,
segments=segments,
# word=None
)
)
case TranscriptionResponseKind.JSON:
return TranscriptionResponse.json(text)
case TranscriptionResponseKind.TEXT:
return TranscriptionResponse.text(text)
                # Unreachable unless a new response kind is added, but fail loudly just in case
raise ValueError(f"Invalid response_kind ({request.response_kind})")
def entrypoint():
# Retrieve endpoint configuration
endpoint_config = EndpointConfig.from_env()
    # If the model is pre-downloaded locally, ensure its architecture is supported
if (model_local_path := Path(endpoint_config.model_id)).exists():
if (config_local_path := (model_local_path / "config.json")).exists():
config = AutoConfig.from_pretrained(config_local_path)
ensure_supported_architectures(config, SUPPORTED_MODEL_ARCHITECTURES)
# Initialize the endpoint
endpoint = AutomaticSpeechRecognitionEndpoint(
WhisperHandler(endpoint_config.model_id)
)
# Serve the model
run(endpoint, endpoint_config.interface, endpoint_config.port)
if __name__ == "__main__":
entrypoint()
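# Illustrative usage: the module is typically launched directly, e.g. `python handler.py`,
# with the model location, interface and port resolved from the environment by
# `EndpointConfig.from_env()`.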