unausagi committed
Commit b06ae2b · verified · 1 Parent(s): 3a18eee

Update app.py

Files changed (1)
  1. app.py +18 -17
app.py CHANGED
```diff
@@ -1,9 +1,8 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 import torch
 import os
 
-# Predefine the Hugging Face models
 MODEL_NAMES = {
     "DeepSeek-R1-Distill-Qwen-7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
     "DeepSeek-R1-Distill-Llama-8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
@@ -13,34 +12,27 @@ HF_TOKEN = os.getenv("HF_TOKEN")
 
 def load_model(model_path):
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, token=HF_TOKEN)
-
-    # Load the config first and manually strip the quantization settings to avoid FP8 issues
     config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, token=HF_TOKEN)
     if hasattr(config, "quantization_config"):
         del config.quantization_config  # Remove the quantization config to avoid FP8
 
     model = AutoModelForCausalLM.from_pretrained(
         model_path,
-        config=config,  # Use the config with quantization removed
+        config=config,
         trust_remote_code=True,
         token=HF_TOKEN,
-        torch_dtype=torch.float16,  # Force FP16 to avoid FP8
+        torch_dtype=torch.float16,
         device_map="auto",
     )
     return model, tokenizer
 
-
-# Load DeepSeek-R1 by default
 current_model, current_tokenizer = load_model("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
 
-
 def chat(message, history, model_name):
-    """Handle a chat message."""
     global current_model, current_tokenizer
 
-    # Switch models when a different one is selected
     if model_name != current_model:
-        current_model, current_tokenizer = load_model(model_name)
+        current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
     inputs = current_tokenizer(message, return_tensors="pt").to(device)
@@ -49,15 +41,24 @@ def chat(message, history, model_name):
 
     return response
 
-
 with gr.Blocks() as app:
     gr.Markdown("## Chatbot with DeepSeek Models")
 
     with gr.Row():
-        chat_interface = gr.ChatInterface(chat, type="messages", flagging_mode="manual", save_history=True)
+        chat_interface = gr.ChatInterface(
+            chat,
+            type="messages",
+            flagging_mode="manual",
+            save_history=True,
+        )
         model_selector = gr.Dropdown(
-            choices=list(MODEL_NAMES.keys()), value="DeepSeek-R1-Distill-Llama-8B", label="Select Model"
+            choices=list(MODEL_NAMES.keys()),
+            value="DeepSeek-R1-Distill-Llama-8B",
+            label="Select Model",
         )
 
-    chat_interface.append(model_selector)
-    app.launch()
+    # Use gr.Blocks layout features to organize the components
+    app.add_component(chat_interface)
+    app.add_component(model_selector)
+
+app.launch()
```
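One detail of the new `chat` function worth flagging: `current_model` holds the model object returned by `load_model`, while `model_name` is the dropdown's string key, so `model_name != current_model` is always true and every message would trigger a full reload. A minimal sketch of one way to avoid that, reusing `MODEL_NAMES` and `load_model` from the file above; `current_model_name` is a hypothetical helper, not part of the commit, and the generation step (which the hunk elides) is filled in here with stock `generate`/`decode` calls as an assumption:

```python
# Sketch: track the selected model by its key so the comparison is
# string-to-string and a reload only happens on an actual switch.
# `current_model_name` is a hypothetical helper, not in the commit.
current_model_name = "DeepSeek-R1-Distill-Llama-8B"
current_model, current_tokenizer = load_model(MODEL_NAMES[current_model_name])

def chat(message, history, model_name):
    global current_model, current_tokenizer, current_model_name

    # Reload only when the dropdown selection actually changed.
    if model_name != current_model_name:
        current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])
        current_model_name = model_name

    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = current_tokenizer(message, return_tensors="pt").to(device)
    # Assumed generation step; the commit's diff does not show these lines.
    outputs = current_model.generate(**inputs, max_new_tokens=256)
    return current_tokenizer.decode(outputs[0], skip_special_tokens=True)
```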
 
 
 
 
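On the layout side, `app.add_component` does not appear to be part of Gradio's public `Blocks` API; components instantiated inside the `with gr.Blocks():` context are registered automatically. The dropdown's value also never reaches `chat` unless the component is passed to `gr.ChatInterface` via `additional_inputs`. A sketch of that wiring under stock Gradio conventions, assuming `MODEL_NAMES` and `chat` are defined as in the diff above:

```python
import gradio as gr

# Sketch: components created inside the `with` context are added to the app
# automatically, so no add_component call is needed.
with gr.Blocks() as app:
    gr.Markdown("## Chatbot with DeepSeek Models")

    model_selector = gr.Dropdown(
        choices=list(MODEL_NAMES.keys()),
        value="DeepSeek-R1-Distill-Llama-8B",
        label="Select Model",
    )
    # Passing the dropdown as an additional input is what delivers its value
    # to chat(message, history, model_name).
    gr.ChatInterface(
        chat,
        type="messages",
        additional_inputs=[model_selector],
    )

app.launch()
```

Calling `app.launch()` outside the `with` block matches the commit's final placement; `ChatInterface` manages its own internal layout, so the `gr.Row()` wrapper becomes optional.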