ericsorides commited on
Commit
dbfea30
·
verified ·
1 Parent(s): 89583f7

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +140 -1
README.md CHANGED
@@ -41,4 +41,143 @@ For more documentation on downloading with `huggingface-cli`, please see: [HF ->
41
 
42
  This model can easily be run on a CPU using [ONNXRuntime](https://onnxruntime.ai/).
43
 
44
- Scripts about how to run these models will be provided soon.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  This model can easily be run on a CPU using [ONNXRuntime](https://onnxruntime.ai/).
43
 
44
+ Here is a sample script to run this model:
45
+
46
+ ```python
47
+ #!/usr/bin/env python3
48
+ import whisper
49
+ import onnx
50
+ import sys
51
+ import time
52
+ import onnxruntime
53
+ from typing import Sequence, Optional
54
+ import numpy as np
55
+ from pathlib import Path
56
+
57
def run_whisper_decoder(decoder_model_path, execution_provider, session_options, decoder_output_names, cross_attn_tensors, num_new_tokens, provider_options=None):
    """Build an ONNXRuntime session for the decoder and run token generation.

    Args:
        decoder_model_path: Path to the decoder ONNX model file.
        execution_provider: ONNXRuntime execution provider name (e.g. 'CPUExecutionProvider').
        session_options: onnxruntime.SessionOptions used when creating the session.
        decoder_output_names: Ordered output tensor names of the decoder graph.
        cross_attn_tensors: Encoder outputs (interleaved per-layer cross-attention K/V).
        num_new_tokens: Maximum number of new tokens to generate.
        provider_options: Optional provider options dict; defaults to empty.

    Returns:
        The decoded transcription string produced by decoder_loop().
    """
    # None sentinel instead of a mutable `{}` default, which would be shared
    # across calls.
    if provider_options is None:
        provider_options = {}
    decoder_session = onnxruntime.InferenceSession(
        decoder_model_path,
        sess_options=session_options,
        providers=[execution_provider],
        provider_options=[provider_options],
    )
    # NOTE: the original measured session-creation and inference time into
    # variables that were never used; the dead timing locals were removed.
    transcription = decoder_loop(decoder_session, decoder_output_names, cross_attn_tensors, num_new_tokens)
    return transcription
64
+
65
+
66
def _run_decoder_step(decoder_session, decoder_output_names, token, attn_mask,
                      self_attn_past_k, self_attn_past_v, cross_attn_k, cross_attn_v):
    """Run one decoder forward pass for a single token.

    Returns (next_token, updated self-attention K cache, updated V cache).
    The KV caches keep a fixed length of 447 past positions, so the oldest
    position is dropped after every step.
    """
    decoder_input = {
        "input_ids": np.array([token], dtype=np.int64).reshape(1, 1),
        "attention_mask": attn_mask,
    }
    # The model has 32 decoder layers; feed each layer's self- and
    # cross-attention cache tensors.
    for i in range(32):
        decoder_input[f"past_key_values.{i}.key"] = self_attn_past_k[i]
        decoder_input[f"past_key_values.{i}.value"] = self_attn_past_v[i]
        decoder_input[f"cross_attn.{i}.key"] = cross_attn_k[i]
        decoder_input[f"cross_attn.{i}.value"] = cross_attn_v[i]

    logits, *cache_tensors = decoder_session.run(decoder_output_names, decoder_input)
    # Greedy decoding: take the highest-probability token.
    next_token = np.argmax(logits[0, 0])

    # Outputs interleave key/value per layer; slice off the oldest position so
    # the cache length stays at 447.
    new_k = [t[:, :, 1:, :] for t in cache_tensors[0::2]]
    new_v = [t[:, :, 1:, :] for t in cache_tensors[1::2]]
    return next_token, new_k, new_v


def decoder_loop(decoder_session, decoder_output_names, cross_attn_tensors, num_new_tokens):
    """Greedily decode up to num_new_tokens tokens and return the transcription.

    First feeds Whisper's special prompt tokens (SOT, language, task,
    no-timestamps), letting the model predict the language token, then
    generates text tokens until EOT or the token budget is exhausted.
    """
    tokenizer = whisper.tokenizer.get_tokenizer(multilingual=True)
    # Prompt: SOT, language placeholder (filled in after the first step),
    # transcribe task, no timestamps.
    first_tokens = np.array([tokenizer.sot, 0, tokenizer.transcribe, tokenizer.no_timestamps], dtype=np.int64)

    # Zero-initialised self-attention KV caches: 32 layers of
    # (batch=1, heads=20, past=447, dim=64) fp16 tensors.
    self_attn_past_k = [np.zeros((1, 20, 447, 64), dtype=np.float16) for _ in range(32)]
    self_attn_past_v = [np.zeros((1, 20, 447, 64), dtype=np.float16) for _ in range(32)]

    # Encoder outputs interleave per-layer cross-attention key/value tensors.
    cross_attn_k = cross_attn_tensors[0::2]
    cross_attn_v = cross_attn_tensors[1::2]

    # Attention mask grows right-to-left, one position per processed token.
    attn_mask_size = 448
    attn_mask = np.zeros((1, attn_mask_size), dtype=np.int64)

    # Process the special prompt tokens one at a time.
    next_token = None
    for j in range(len(first_tokens)):
        attn_mask[0, -1 - j] = 1
        next_token, self_attn_past_k, self_attn_past_v = _run_decoder_step(
            decoder_session, decoder_output_names, first_tokens[j], attn_mask,
            self_attn_past_k, self_attn_past_v, cross_attn_k, cross_attn_v)
        if j == 0:
            # After SOT the model predicts the language token; patch it into
            # the prompt before it is fed on the next iteration.
            first_tokens[1] = next_token

    # Generate new text tokens until EOT or the budget runs out.
    transcribed_tokens = [next_token]
    for j in range(4, 4 + num_new_tokens):
        attn_mask[0, -1 - j] = 1
        next_token, self_attn_past_k, self_attn_past_v = _run_decoder_step(
            decoder_session, decoder_output_names, transcribed_tokens[-1], attn_mask,
            self_attn_past_k, self_attn_past_v, cross_attn_k, cross_attn_v)
        if next_token == tokenizer.eot:  # end_of_transcription
            break
        transcribed_tokens.append(next_token)

    return tokenizer.decode(transcribed_tokens)
136
+
137
+
138
def main(argv: Optional[Sequence[str]] = None):
    """Transcribe a sample audio file with the ONNX Whisper model on CPU.

    Loads the audio, computes the log-mel spectrogram, runs the encoder to
    obtain cross-attention tensors, then runs the decoder generation loop and
    prints the resulting transcription.
    """
    # Whisper operates on (up to) 30 s windows; pad/trim the audio to this length.
    num_seconds = 28.8

    speech_path = 'sample_audio.wav'
    encoder_model_path = 'whisper-large-v3-kvc-fp16-onnx/encoder/model.onnx'
    decoder_model_path = 'whisper-large-v3-kvc-fp16-onnx/decoder/model.onnx'

    # Load audio and build the 128-bin log-mel spectrogram the encoder expects.
    print(f"Spectrogram speech audio file {speech_path}... ", end="")
    audio = whisper.load_audio(speech_path)
    audio = whisper.pad_or_trim(audio, length=int(num_seconds * 16000))
    mel = whisper.log_mel_spectrogram(audio, n_mels=128).unsqueeze(0)  # Unsqueeze to set batch=1
    print("OK")

    print("Running encoder... ", end="")

    # Session options
    session_options = onnxruntime.SessionOptions()
    # Enable all graph optimizations. (The original comment claimed these were
    # being disabled, which contradicted ORT_ENABLE_ALL.)
    session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL

    # Encode. The graph is loaded without external data only to read the
    # output tensor names.
    encoder = onnx.load(encoder_model_path, load_external_data=False)
    encoder_input = {"mel": mel.numpy().astype('float16')}
    encoder_output_names = [tensor.name for tensor in encoder.graph.output]
    # CPU encoding
    cpu_provider = 'CPUExecutionProvider'
    enc_session_cpu = onnxruntime.InferenceSession(encoder_model_path, sess_options=session_options, providers=[cpu_provider])
    cross_attn_tensors_cpu = enc_session_cpu.run(encoder_output_names, encoder_input)

    print("OK")

    # Number of new tokens to generate.
    new_tokens = 20

    # Run decoder model on CPU. Load the graph to read its output names.
    decoder = onnx.load(decoder_model_path, load_external_data=False)
    decoder_output_names = [tensor.name for tensor in decoder.graph.output]

    transcription = run_whisper_decoder(decoder_model_path, cpu_provider, session_options, decoder_output_names, cross_attn_tensors_cpu, new_tokens)
    # The transcription was previously computed and silently discarded; print
    # it so the sample actually shows its result.
    print(transcription)
179
+
180
+
181
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main(sys.argv[1:]))
183
+ ```