thliang01 committed
Commit 5fa88f4 · verified · 1 Parent(s): cfc8b04

Update app.py

* Set pad_token
* Use return_dict=True to obtain the attention_mask
* Add the required parameters to generate_kwargs
* Add error handling and memory cleanup
* Increase the streamer timeout
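Together these fixes follow one standard transformers pattern: give the tokenizer a pad token, request a dict from apply_chat_template so the attention_mask comes back alongside input_ids, and pass both the mask and pad_token_id explicitly to generate(). Below is a minimal, non-streaming sketch of that pattern; the prompt and max_new_tokens are placeholders, and the actual app (see the diff) streams through TextIteratorStreamer on a background thread inside try/finally.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "taide/Llama3-TAIDE-LX-8B-Chat-Alpha1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Llama-style tokenizers often ship without a pad token; fall back to EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# return_dict=True returns input_ids *and* attention_mask.
inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": "你好"}],  # placeholder prompt
    return_tensors="pt",
    return_dict=True,
    add_generation_prompt=True,
)

# Pass attention_mask and pad_token_id explicitly to avoid the
# "pad_token_id not set" warning and ambiguous padding behavior.
output = model.generate(
    input_ids=inputs["input_ids"].to(model.device),
    attention_mask=inputs["attention_mask"].to(model.device),
    max_new_tokens=32,  # placeholder
    pad_token_id=tokenizer.pad_token_id,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```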

Files changed (1):
  app.py (+57, −44)
app.py CHANGED
@@ -1,18 +1,13 @@
 import gradio as gr
 import spaces
 import os
-import spaces
 import torch
-from transformers import GemmaTokenizer, AutoModelForCausalLM
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
 from threading import Thread
 
 # Set an environment variable
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
-zero = torch.Tensor([0]).cuda()
-print(zero.device) # <-- 'cpu' 🤔
-
 DESCRIPTION = '''
 <div>
 <h1 style="text-align: center;">TAIDE/Llama3-TAIDE-LX-8B-Chat-Alpha1</h1>
@@ -41,7 +36,12 @@ h1 {
 
 # Load the tokenizer and model
 tokenizer = AutoTokenizer.from_pretrained("taide/Llama3-TAIDE-LX-8B-Chat-Alpha1")
-model = AutoModelForCausalLM.from_pretrained("taide/Llama3-TAIDE-LX-8B-Chat-Alpha1") # to("cuda:0")
+model = AutoModelForCausalLM.from_pretrained("taide/Llama3-TAIDE-LX-8B-Chat-Alpha1")
+
+# Set pad_token_id (key fix)
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
 terminators = [
     tokenizer.eos_token_id,
     tokenizer.convert_tokens_to_ids("<|eot_id|>")
@@ -55,47 +55,60 @@ def chat_taide_8b(message: str,
                   ) -> str:
     """
     Generate a streaming response using the llama3-8b model.
-    Args:
-        message (str): The input message.
-        history (list): The conversation history used by ChatInterface.
-        temperature (float): The temperature for generating the response.
-        max_new_tokens (int): The maximum number of new tokens to generate.
-    Returns:
-        str: The generated response.
     """
-    conversation = []
-    for user, assistant in history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt").to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-
-    generate_kwargs = dict(
-        input_ids= input_ids,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        temperature=temperature,
-        eos_token_id=terminators,
-    )
-    # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
-    if temperature == 0:
-        generate_kwargs['do_sample'] = False
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        #print(outputs)
-        yield "".join(outputs)
+    try:
+        conversation = []
+        for user, assistant in history:
+            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+        conversation.append({"role": "user", "content": message})
+
+        # Use return_dict=True to obtain the attention_mask (key fix)
+        inputs = tokenizer.apply_chat_template(
+            conversation,
+            return_tensors="pt",
+            return_dict=True,
+            add_generation_prompt=True
+        )
+
+        input_ids = inputs["input_ids"].to(model.device)
+        attention_mask = inputs.get("attention_mask", None)
+        if attention_mask is not None:
+            attention_mask = attention_mask.to(model.device)
+
+        streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
+
+        generate_kwargs = dict(
+            input_ids=input_ids,
+            attention_mask=attention_mask,  # include the attention_mask
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            temperature=temperature,
+            eos_token_id=terminators,
+            pad_token_id=tokenizer.pad_token_id,  # explicitly set pad_token_id
+        )
+
+        # This will enforce greedy generation (do_sample=False) when the temperature is passed 0, avoiding the crash.
+        if temperature == 0:
+            generate_kwargs['do_sample'] = False
+
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        outputs = []
+        for text in streamer:
+            outputs.append(text)
+            yield "".join(outputs)
+
+    except Exception as e:
+        yield f"生成過程中發生錯誤: {str(e)}"
+    finally:
+        # Clean up GPU memory
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
 # Gradio block
-chatbot=gr.Chatbot(height=450, label='Gradio ChatInterface')
+chatbot = gr.Chatbot(height=450, label='Gradio ChatInterface')
 
 with gr.Blocks(fill_height=True, css=css) as demo:
@@ -118,15 +131,15 @@ with gr.Blocks(fill_height=True, css=css) as demo:
                       step=1,
                       value=512,
                       label="Max new tokens",
-                      render=False ),
-        ],
+                      render=False),
+    ],
     examples=[
         ['請以以下內容為基礎,寫一篇文章:撰寫一篇作文,題目為《一張舊照片》,內容要求為:選擇一張令你印象深刻的照片,說明令你印象深刻的原因,並描述照片中的影像及背後的故事。記錄成長的過程、與他人的情景、環境變遷和美麗的景色。'],
         ['請以品牌經理的身份,給廣告公司的創意總監寫一封信,提出對於新產品廣告宣傳活動的創意建議。'],
         ['以下提供英文內容,請幫我翻譯成中文。Dongshan coffee is famous for its unique position, and the constant refinement of production methods. The flavor is admired by many caffeine afficionados.'],
-        ],
+    ],
     cache_examples=False,
-    )
+)
 
 gr.Markdown(LICENSE)
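A note on the streamer change: TextIteratorStreamer's timeout is how long the consuming iterator waits for the next token before raising queue.Empty, so raising it from 10s to 30s gives the 8B model more headroom before the first token arrives. A self-contained sketch of the consume side follows; "gpt2" is a placeholder model so the snippet runs anywhere, whereas the app uses the TAIDE 8B model.

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model so the sketch runs without a GPU.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# timeout: seconds the iterator waits for the next token before raising queue.Empty.
streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)

inputs = tokenizer("Hello", return_tensors="pt")
thread = Thread(target=model.generate, kwargs=dict(
    **inputs,
    streamer=streamer,
    max_new_tokens=20,
    pad_token_id=tokenizer.eos_token_id,  # gpt2 also ships without a pad token
))
thread.start()

# The main thread consumes tokens as they arrive; a first token slower than
# the timeout would raise here, which the app's try/except now reports.
for text in streamer:
    print(text, end="", flush=True)
thread.join()
```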