wiklif committed
Commit 8475fdd · 1 Parent(s): f7fc778
Files changed (1):
  1. app.py +38 -18
app.py CHANGED
@@ -2,29 +2,49 @@ import os
 import spaces
 import gradio as gr
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
+from threading import Thread
+from queue import Queue
 
 model_id = "meta-llama/Meta-Llama-3.1-8B"
+tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("MY_API_LLAMA_3_1"))
 
-@spaces.GPU(duration=60)
-def load_model():
-    tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.environ.get("MY_API_LLAMA_3_1"))
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        token=os.environ.get("MY_API_LLAMA_3_1"),
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-        low_cpu_mem_usage=True
-    )
-    return pipeline("text-generation", model=model, tokenizer=tokenizer)
+model = None
+model_load_queue = Queue()
 
-pipe = load_model()
+def load_model():
+    global model
+    if model is None:
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            token=os.environ.get("MY_API_LLAMA_3_1"),
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            low_cpu_mem_usage=True
+        )
+    model_load_queue.put(model)
 
 @spaces.GPU(duration=60)
 def generate_response(chat, kwargs):
-    output = pipe(chat, **kwargs)[0]['generated_text']
-    if output.endswith("</s>"):
-        output = output[:-4]
+    global model
+    if model is None:
+        Thread(target=load_model).start()
+        model = model_load_queue.get()
+
+    inputs = tokenizer(chat, return_tensors="pt").to(model.device)
+    streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
+
+    generation_kwargs = dict(inputs, streamer=streamer, **kwargs)
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    output = ""
+    for new_text in streamer:
+        output += new_text
+        if output.endswith("</s>"):
+            output = output[:-4]
+            break
+
     return output
 
 def function(prompt, history=[]):
@@ -33,11 +53,11 @@ def function(prompt, history=[]):
     chat += f"[INST] {user_prompt} [/INST] {bot_response}</s> <s>"
     chat += f"[INST] {prompt} [/INST]"
     kwargs = dict(
-        temperature=0.5,
         max_new_tokens=4096,
+        do_sample=True,
+        temperature=0.5,
         top_p=0.95,
         repetition_penalty=1.0,
-        do_sample=True,
         seed=1337
     )
 
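
For reference, a minimal self-contained sketch of the Thread + TextIteratorStreamer pattern that the new generate_response relies on; the "gpt2" checkpoint and prompt here are hypothetical stand-ins so the snippet runs without a gated model:

# Sketch only: same streaming pattern as generate_response above.
# "gpt2" is a stand-in checkpoint (assumption, not part of this commit).
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("[INST] Hello [/INST]", return_tensors="pt").to(model.device)
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread while the main thread
# consumes decoded text chunks from the streamer as they arrive.
thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=32))
thread.start()

output = ""
for new_text in streamer:
    output += new_text
thread.join()
print(output)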