seba committed on
Commit 597c8e2 · verified · 1 Parent(s): bf01923

Upload falcon_edge_generate.py

Files changed (1)
  1. falcon_edge_generate.py +80 -45
falcon_edge_generate.py CHANGED
@@ -5,6 +5,7 @@ import time
 from transformers import AutoTokenizer
 import shutil
 from argparse import ArgumentParser
+import asyncio
 
 
 def copy_compiled_model(mlmodel: ct.models.MLModel, dest: str):
@@ -35,6 +36,30 @@ def load_embeddings(path):
     return np.load(path)
 
 
+async def generate_single_step(
+    input_id,
+    embed_fn,
+    model,
+    state,
+    position,
+    attention_mask_ref,
+    lm_head,
+):
+    embd = embed_fn(input_id).transpose(0, 3, 1, 2)
+    hidden_states = model.predict(
+        {
+            "hidden_states": embd,
+            "kv_write_idx": np.array([position], dtype=np.int32),
+            "positions": np.array([[position]], dtype=np.int32),
+            "attention_mask": attention_mask_ref[:, :, [position]],
+        },
+        state,
+    )["output_hidden_states"]
+    if lm_head is not None:
+        input_id = lm_head(hidden_states)
+    return input_id
+
+
 class ModelContainer:
     def __init__(
         self,
@@ -73,13 +98,11 @@ class ModelContainer:
         )
         self.tokenizer = AutoTokenizer.from_pretrained(hf_model)
         self.end_of_response_token_id = self.tokenizer("<|im_end|>").input_ids[0]
+        self.end_of_text_token_id = self.tokenizer("<|end_of_text|>").input_ids[0]
+        self.break_tokens = [self.end_of_response_token_id, self.end_of_text_token_id]
 
         self.state = None
         self.position = None
-        self.attention_mask = None
-
-    def initialize_generation(self):
-        self.state = self.generation_model.make_state()
         attention_mask = np.arange(self.cache_length, dtype=np.int32)
         attention_mask = attention_mask[:, None] >= attention_mask[None, :]
         attention_mask = attention_mask[None, None, :, :]
@@ -88,6 +111,9 @@ class ModelContainer:
             np.array(0.0, dtype=np.float16),
             np.array(-np.inf, dtype=np.float16),
         )
+
+    def initialize_generation(self):
+        self.state = self.generation_model.make_state()
         self.position = 0
 
     def load_prompt_model(self):
@@ -156,7 +182,7 @@ class ModelContainer:
         self.unload_prompt_model()
         end_time = time.perf_counter()
         print(
-            f"==== Processed {processed_chunks * 64} tokens in {end_time - start_time:.2f} seconds, {processed_chunks * 64 / (end_time - start_time):.2f} tokens per second, current position: {self.position}",
+            f"==== Processed {len(tokens)} tokens + {64 - len(chunk)} pad tokens in {end_time - start_time:.2f} seconds, {processed_chunks * 64 / (end_time - start_time):.2f} tokens per second, current position: {self.position}/{self.cache_length}",
         )
         if stop_processing:
             return np.array([-1], dtype=np.int32)
@@ -183,60 +209,69 @@ class ModelContainer:
         ][:, 0]
         return input_id
 
-    def generate(self, input_id: np.array):
-        stop_generation = False
+    async def generate(self, input_id: np.array):
+        continue_generating = True
         # for i in range(max_new_tokens):
-        start_time = time.perf_counter()
         generated_tokens = 0
-        while self.position < self.cache_length:
-            generated_tokens += 1
-            embd = self.embed(input_id).transpose(0, 3, 1, 2)
-            hidden_states = self.generation_model.predict(
-                {
-                    "hidden_states": embd,
-                    "kv_write_idx": np.array([self.position], dtype=np.int32),
-                    "positions": np.array([[self.position]], dtype=np.int32),
-                    "attention_mask": self.attention_mask[:, :, [self.position]],
-                },
-                self.state,
-            )["output_hidden_states"]
-            if stop_generation:
-                print()
-                # print("Loading prompt model...")
-                self.position += 1
-                break
-
-            input_id = self.lm_head(hidden_states)
+        start_time = time.perf_counter()
+        # task = asyncio.create_task(generate_single_step(
+        #     input_id,
+        #     self.embed,
+        #     self.generation_model,
+        #     self.state,
+        #     self.position,
+        #     self.attention_mask,
+        #     self.lm_head,
+        # ))
 
+        while (self.position < self.cache_length) and continue_generating:
+            generated_tokens += 1
             input_id_item = input_id.item()
-            if input_id_item == self.end_of_response_token_id:
-                stop_generation = True
-            print(self.tokenizer.decode(input_id_item), end="", flush=True)
+            if input_id_item in self.break_tokens:
+                continue_generating = False
+            task = asyncio.create_task(
+                generate_single_step(
+                    input_id,
+                    self.embed,
+                    self.generation_model,
+                    self.state,
+                    self.position,
+                    self.attention_mask,
+                    self.lm_head if continue_generating else None,
+                )
+            )
             self.position += 1
+            print(self.tokenizer.decode(input_id_item), end="", flush=True)
+            input_id = await task
+
+        print()
 
         end_time = time.perf_counter()
         print(
-            f"==== Generated {generated_tokens} tokens in {end_time - start_time:.2f} seconds, {generated_tokens / (end_time - start_time):.2f} tokens per second, current position: {self.position}",
+            f"==== Generated {generated_tokens} tokens in {end_time - start_time:.2f} seconds, {generated_tokens / (end_time - start_time):.2f} tokens per second, current position: {self.position}/{self.cache_length}",
         )
         # if stop_generation:
         #     self.load_prompt_model()
 
     def loop(self):
-        self.initialize_generation()
-        print("Begin conversation...")
+        print("--- Begin conversation ---")
         while True:
-            print(">>> ", end="", flush=True)
-            self.load_prompt_model()
-            prompt = input()
-            prompt_result = self.process_prompt(prompt)
-            if prompt_result.item() == -1:
-                print("\n--- END OF CONVERSATION: MAX CONTEXT LENGTH REACHED ---\n")
-                break
-            print(self.tokenizer.decode(prompt_result.item()), end="", flush=True)
-            self.generate(prompt_result)
-            if self.position >= (self.cache_length):
-                print("\n--- END OF CONVERSATION: MAX CONTEXT LENGTH REACHED ---\n")
-                break
+            self.initialize_generation()
+            while True:
+                print(">>> ", end="", flush=True)
+                self.load_prompt_model()
+                prompt = input()
+                prompt_result = self.process_prompt(prompt)
+                if prompt_result.item() == -1:
+                    print("\n--- END OF CONVERSATION: MAX CONTEXT LENGTH REACHED ---\n")
+                    print("--- Beginning new conversation ---")
+                    break
+                # print(self.tokenizer.decode(prompt_result.item()), end="", flush=True)
+                asyncio.run(self.generate(prompt_result))
+                if self.position >= (self.cache_length):
+                    print("\n--- END OF CONVERSATION: MAX CONTEXT LENGTH REACHED ---\n")
+                    print("--- Beginning new conversation ---")
+                    break
 
 
 def parse_args():
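
For readers of the diff: the rewritten `generate` follows a schedule/emit/await loop — each forward pass is wrapped in an asyncio task via `generate_single_step`, the current token is printed while that task is pending, and the result is awaited to obtain the next `input_id`. On the final step, `lm_head` is passed as `None` so the logits projection is skipped, since no further token is needed. Below is a minimal, self-contained sketch of that pattern; `slow_forward`, its latency, and the token arithmetic are illustrative stand-ins, not this repository's API. Because a coroutine body with no internal `await` (like `generate_single_step`) only executes once awaited, the sketch pushes the blocking step onto a worker thread with `asyncio.to_thread` so it genuinely overlaps with the printing.

import asyncio
import time


def slow_forward(token: int) -> int:
    """Stand-in for a blocking predict + lm_head step (hypothetical)."""
    time.sleep(0.05)  # pretend model latency
    return token + 1  # pretend next-token selection


async def generate(first_token: int, max_tokens: int = 5) -> None:
    token = first_token
    for _ in range(max_tokens):
        # Schedule the next forward step off the event loop.
        task = asyncio.create_task(asyncio.to_thread(slow_forward, token))
        await asyncio.sleep(0)  # yield once so the task's worker thread starts
        # Emit the current token while the next step runs in the background.
        print(f"token={token}", flush=True)
        token = await task  # pick up the next token


if __name__ == "__main__":
    asyncio.run(generate(first_token=0))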