import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import os

# Display names shown in the dropdown mapped to Hugging Face repo ids
MODEL_NAMES = {
    "DeepSeek-R1-Distill-Qwen-7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "DeepSeek-R1-Distill-Llama-8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
}

HF_TOKEN = os.getenv("HF_TOKEN")  # optional; only needed for gated or private repos


def load_model(model_path):
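    """Load the tokenizer and FP16 model for the given Hugging Face repo id."""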
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True, token=HF_TOKEN
    )
    config = AutoConfig.from_pretrained(
        model_path, trust_remote_code=True, token=HF_TOKEN
    )
    if hasattr(config, "quantization_config"):
        del config.quantization_config  # Drop the quantization config so the model is not loaded in FP8

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        trust_remote_code=True,
        token=HF_TOKEN,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return model, tokenizer


# Load the default model at startup
current_model_name = "DeepSeek-R1-Distill-Llama-8B"
current_model, current_tokenizer = load_model(MODEL_NAMES[current_model_name])


def chat(message, history, model_name):
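    """ChatInterface callback: reload the model if the selection changed, then generate a reply."""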
    global current_model, current_tokenizer, current_model_name

    # Reload only if a different model was selected in the dropdown
    if model_name != current_model_name:
        current_model_name = model_name
        current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])

    # Build the prompt from the chat history plus the new user message using the
    # model's chat template, so the distilled chat models see their expected format
    messages = list(history) + [{"role": "user", "content": message}]
    prompt = current_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = current_tokenizer(prompt, return_tensors="pt").to(device)
    outputs = current_model.generate(**inputs, max_new_tokens=1024)
    # Decode only the newly generated tokens so the reply does not echo the prompt
    response = current_tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
    )

    return response


with gr.Blocks() as app:
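    # UI layout: a model selector dropdown above the chat interface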
    gr.Markdown("## Chatbot with DeepSeek Models")

    model_selector = gr.Dropdown(
        choices=list(MODEL_NAMES.keys()),
        value=current_model_name,
        label="Select Model",
    )

    # Pass the dropdown through additional_inputs so chat() receives the value
    # currently selected in the UI; reading model_selector.value inside a lambda
    # would always return the initial value and the model would never switch.
    chat_interface = gr.ChatInterface(
        fn=chat,
        additional_inputs=[model_selector],
        type="messages",
        flagging_mode="manual",
        save_history=True,
    )

app.launch()