Update app.py
app.py CHANGED
@@ -26,12 +26,16 @@ def load_model(model_path):
     )
     return model, tokenizer

-
+# Initialize the default model
+current_model_name = "DeepSeek-R1-Distill-Llama-8B"
+current_model, current_tokenizer = load_model(MODEL_NAMES[current_model_name])

 def chat(message, history, model_name):
-    global current_model, current_tokenizer
+    global current_model, current_tokenizer, current_model_name

-
+    # Check whether the model needs to be switched
+    if model_name != current_model_name:
+        current_model_name = model_name
         current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])

     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -43,18 +47,20 @@ def chat(message, history, model_name):

 with gr.Blocks() as app:
     gr.Markdown("## Chatbot with DeepSeek Models")
+
     model_selector = gr.Dropdown(
-
-
-
-
+        choices=list(MODEL_NAMES.keys()),
+        value=current_model_name,
+        label="Select Model",
+    )
+
     chat_interface = gr.ChatInterface(
-        chat,
+        fn=lambda message, history: chat(message, history, model_selector.value),
         type="messages",
         flagging_mode="manual",
         save_history=True,
     )

+    model_selector.change(fn=lambda model_name: None, inputs=[model_selector], outputs=[])

-
-app.launch()
+app.launch()
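For comparison, a minimal sketch of an alternative wiring, not part of the commit. In Gradio, reading model_selector.value inside a lambda returns the dropdown's initial value rather than the user's live selection; passing the component through gr.ChatInterface's additional_inputs parameter forwards the current selection to chat() as its model_name argument on every message. MODEL_NAMES, current_model_name, load_model(), and chat() are assumed to be the objects defined earlier in this app.py.

import gradio as gr

# Assumes MODEL_NAMES, current_model_name, load_model() and chat() from the
# diff above are already defined in this module.

with gr.Blocks() as app:
    gr.Markdown("## Chatbot with DeepSeek Models")

    model_selector = gr.Dropdown(
        choices=list(MODEL_NAMES.keys()),
        value=current_model_name,
        label="Select Model",
    )

    # Components listed in additional_inputs are appended to (message, history)
    # when the chat function is invoked, so the dropdown's current selection
    # reaches chat(message, history, model_name) directly; no model_selector.value
    # lookup and no separate .change() listener are needed for this.
    chat_interface = gr.ChatInterface(
        fn=chat,
        additional_inputs=[model_selector],
        type="messages",
        flagging_mode="manual",
        save_history=True,
    )

app.launch()

With this wiring the no-op model_selector.change(...) handler from the commit can be dropped, since the model switch already happens inside chat() whenever the selected name differs from current_model_name.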