unausagi commited on
Commit
c4a5299
·
verified ·
1 Parent(s): 90f794a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -10
app.py CHANGED
@@ -26,12 +26,16 @@ def load_model(model_path):
26
  )
27
  return model, tokenizer
28
 
29
- current_model, current_tokenizer = load_model("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
 
 
30
 
31
  def chat(message, history, model_name):
32
- global current_model, current_tokenizer
33
 
34
- if model_name != current_model:
 
 
35
  current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])
36
 
37
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -43,18 +47,20 @@ def chat(message, history, model_name):
43
 
44
  with gr.Blocks() as app:
45
  gr.Markdown("## Chatbot with DeepSeek Models")
 
46
  model_selector = gr.Dropdown(
47
- choices=list(MODEL_NAMES.keys()),
48
- value="DeepSeek-R1-Distill-Llama-8B",
49
- label="Select Model",
50
- )
 
51
  chat_interface = gr.ChatInterface(
52
- chat,
53
  type="messages",
54
  flagging_mode="manual",
55
  save_history=True,
56
  )
57
 
 
58
 
59
-
60
- app.launch()
 
26
  )
27
  return model, tokenizer
28
 
29
+ # 初始化預設模型
30
+ current_model_name = "DeepSeek-R1-Distill-Llama-8B"
31
+ current_model, current_tokenizer = load_model(MODEL_NAMES[current_model_name])
32
 
33
  def chat(message, history, model_name):
34
+ global current_model, current_tokenizer, current_model_name
35
 
36
+ # 檢查是否需要更換模型
37
+ if model_name != current_model_name:
38
+ current_model_name = model_name
39
  current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])
40
 
41
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
47
 
48
  with gr.Blocks() as app:
49
  gr.Markdown("## Chatbot with DeepSeek Models")
50
+
51
  model_selector = gr.Dropdown(
52
+ choices=list(MODEL_NAMES.keys()),
53
+ value=current_model_name,
54
+ label="Select Model",
55
+ )
56
+
57
  chat_interface = gr.ChatInterface(
58
+ fn=lambda message, history: chat(message, history, model_selector.value),
59
  type="messages",
60
  flagging_mode="manual",
61
  save_history=True,
62
  )
63
 
64
+ model_selector.change(fn=lambda model_name: None, inputs=[model_selector], outputs=[])
65
 
66
+ app.launch()