---
tags:
- text-generation-inference
- whisper
- audio
base_model:
- openai/whisper-large-v3
---

# Whisper Large v3 with Key-Value-Cache enabled in ONNX fp16 format

- Model creator: [OpenAI](https://huggingface.co/openai)
- Original model: [Whisper Large v3](https://huggingface.co/openai/whisper-large-v3)

## Description

This repo contains the ONNX files for the conversion of Whisper Large v3 to ONNX, done by Esperanto Technologies. The model is in fp16 format and has the key-value cache (KVC) enabled.

## How to download ONNX model and weight files

The easiest way to obtain the model is to clone this whole repo. Alternatively, you can download the files using the `huggingface-hub` Python library.

```shell
pip3 install "huggingface-hub>=0.17.1"
```

Then you can download any individual model file to the current directory, at high speed, with a command like this:

```shell
huggingface-cli download Esperanto/whisper-large-v3-kvc-fp16-onnx --local-dir whisper-large-v3-kvc-fp16-onnx --local-dir-use-symlinks False
```

For more documentation on downloading with `huggingface-cli`, please see: [HF -> Hub Python Library -> Download files -> Download from the CLI](https://huggingface.co/docs/huggingface_hub/guides/download#download-from-the-cli).

## How to run from Python code using ONNXRuntime

This model can easily be run on a CPU using [ONNXRuntime](https://onnxruntime.ai/). Here is a sample script to run this model:

```python
#!/usr/bin/env python3
import sys
import time
from typing import Optional, Sequence

import numpy as np
import onnx
import onnxruntime
import whisper


def run_whisper_decoder(decoder_model_path, execution_provider, session_options,
                        decoder_output_names, cross_attn_tensors, num_new_tokens,
                        provider_options=None):
    start = time.time()
    decoder_session = onnxruntime.InferenceSession(decoder_model_path,
                                                   sess_options=session_options,
                                                   providers=[execution_provider],
                                                   provider_options=[provider_options or {}])
    compile_time = time.time()
    transcription = decoder_loop(decoder_session, decoder_output_names,
                                 cross_attn_tensors, num_new_tokens)
    inference_time = time.time()
    print(f"Session creation: {compile_time - start:.2f} s, "
          f"decoding: {inference_time - compile_time:.2f} s")
    return transcription


def decoder_loop(decoder_session, decoder_output_names, cross_attn_tensors, num_new_tokens):
    # Generate the start-of-transcription tokens; index 1 is a placeholder
    # for the language token, filled in after the first decoding step
    tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True)
    first_tokens = np.array([tokenizer.sot, 0, tokenizer.transcribe,
                             tokenizer.no_timestamps], dtype=np.int64)

    # Self-attention past key/value tensors (one pair per decoder layer)
    self_attn_past_k = []
    self_attn_past_v = []
    for i in range(32):
        self_attn_past_k.append(np.zeros((1, 20, 447, 64), dtype=np.float16))
        self_attn_past_v.append(np.zeros((1, 20, 447, 64), dtype=np.float16))

    # Cross-attention key/value tensors computed by the encoder
    cross_attn_k = cross_attn_tensors[0::2]
    cross_attn_v = cross_attn_tensors[1::2]

    # Attention mask
    attn_mask_size = 448
    attn_mask = np.zeros((1, attn_mask_size), dtype=np.int64)

    # Process first tokens
    for j in range(len(first_tokens)):
        tokens = np.array([first_tokens[j]], dtype=np.int64).reshape(1, 1)
        attn_mask[0, -1 - j] = 1

        decoder_input = {"input_ids": tokens, "attention_mask": attn_mask}
        for i in range(32):
            decoder_input[f"past_key_values.{i}.key"] = self_attn_past_k[i]
            decoder_input[f"past_key_values.{i}.value"] = self_attn_past_v[i]
            decoder_input[f"cross_attn.{i}.key"] = cross_attn_k[i]
            decoder_input[f"cross_attn.{i}.value"] = cross_attn_v[i]

        logits, *cache_tensors = decoder_session.run(decoder_output_names, decoder_input)
        next_token = np.argmax(logits[0, 0])
        self_attn_k = cache_tensors[0::2]
        self_attn_v = cache_tensors[1::2]
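        # Sliding-window update of the KV cache: the decoder returns
        # self-attention keys/values for the full 448-token context, so we
        # drop the oldest position to keep the "past" cache at 447 entries
        # and leave room for the next token.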
        for i in range(32):
            self_attn_past_k[i] = self_attn_k[i][:, :, 1:, :]
            self_attn_past_v[i] = self_attn_v[i][:, :, 1:, :]

        if j == 0:
            # Set the language token detected on the first step
            first_tokens[1] = next_token

    # Autoregressive decoding of new tokens
    transcribed_tokens = [next_token]
    for j in range(4, 4 + num_new_tokens):
        tokens = np.array([transcribed_tokens[-1]], dtype=np.int64).reshape(1, 1)
        attn_mask[0, -1 - j] = 1

        decoder_input = {"input_ids": tokens, "attention_mask": attn_mask}
        for i in range(32):
            decoder_input[f"past_key_values.{i}.key"] = self_attn_past_k[i]
            decoder_input[f"past_key_values.{i}.value"] = self_attn_past_v[i]
            decoder_input[f"cross_attn.{i}.key"] = cross_attn_k[i]
            decoder_input[f"cross_attn.{i}.value"] = cross_attn_v[i]

        logits, *cache_tensors = decoder_session.run(decoder_output_names, decoder_input)
        next_token = np.argmax(logits[0, 0])

        if next_token == tokenizer.eot:  # end of transcription
            break

        transcribed_tokens.append(next_token)
        self_attn_k = cache_tensors[0::2]
        self_attn_v = cache_tensors[1::2]
        for i in range(32):
            self_attn_past_k[i] = self_attn_k[i][:, :, 1:, :]
            self_attn_past_v[i] = self_attn_v[i][:, :, 1:, :]

    return tokenizer.decode(transcribed_tokens)


def main(argv: Optional[Sequence[str]] = None):
    num_seconds = 28.8
    speech_path = 'sample_audio.wav'
    encoder_model_path = 'whisper-large-v3-kvc-fp16-onnx/encoder/model.onnx'
    decoder_model_path = 'whisper-large-v3-kvc-fp16-onnx/decoder/model.onnx'

    # Load audio and compute the 128-bin log-Mel spectrogram
    print(f"Computing spectrogram for speech audio file {speech_path}... ", end="")
    audio = whisper.load_audio(speech_path)
    audio = whisper.pad_or_trim(audio, length=int(num_seconds * 16000))
    mel = whisper.log_mel_spectrogram(audio, n_mels=128).unsqueeze(0)  # unsqueeze to set batch=1
    print("OK")

    print("Running encoder... ", end="")

    # Session options
    session_options = onnxruntime.SessionOptions()
    # Enable all the graph optimizations
    session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

    # Read the encoder graph to get its output names (the cross-attention tensors)
    encoder = onnx.load(encoder_model_path, load_external_data=False)
    encoder_input = {"mel": mel.numpy().astype('float16')}
    encoder_output_names = [tensor.name for tensor in encoder.graph.output]

    # CPU encoding
    cpu_provider = 'CPUExecutionProvider'
    enc_session_cpu = onnxruntime.InferenceSession(encoder_model_path,
                                                   sess_options=session_options,
                                                   providers=[cpu_provider])
    cross_attn_tensors_cpu = enc_session_cpu.run(encoder_output_names, encoder_input)
    print("OK")

    # Decoding parameters
    max_context = 448  # maximum decoder context length
    new_tokens = 20    # number of new tokens to generate

    # Run decoder model on CPU
    decoder = onnx.load(decoder_model_path, load_external_data=False)
    decoder_output_names = [tensor.name for tensor in decoder.graph.output]
    transcription = run_whisper_decoder(decoder_model_path, cpu_provider, session_options,
                                        decoder_output_names, cross_attn_tensors_cpu, new_tokens)
    print(transcription)


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
```
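
To try the script end to end, you need the Python packages it imports, plus `ffmpeg` on the PATH (used by `whisper.load_audio`). The script and audio file names below are illustrative placeholders; adjust them to your local paths:

```shell
pip3 install openai-whisper onnx onnxruntime numpy
# assumes the script was saved as run_whisper.py next to a sample_audio.wav
python3 run_whisper.py
```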
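
The `run_whisper_decoder` helper is not tied to the CPU provider: any ONNX Runtime execution provider available in your installation can be passed in. A minimal sketch, assuming the `onnxruntime-gpu` package is installed and reusing the names computed in `main()` above:

```python
# Hypothetical variation: run the decoder on a GPU instead of the CPU.
# CUDAExecutionProvider ships with the onnxruntime-gpu package; the
# "device_id" provider option selects which GPU to use.
transcription = run_whisper_decoder(decoder_model_path,
                                    "CUDAExecutionProvider",
                                    session_options,
                                    decoder_output_names,
                                    cross_attn_tensors_cpu,
                                    new_tokens,
                                    provider_options={"device_id": 0})
print(transcription)
```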