import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import os
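
# Map each UI label shown in the dropdown to its Hugging Face Hub repo id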
MODEL_NAMES = {
    "DeepSeek-R1-Distill-Qwen-7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
    "DeepSeek-R1-Distill-Llama-8B": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
}
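
# Access token read from the environment; only needed for gated/private repos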
HF_TOKEN = os.getenv("HF_TOKEN")
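
# Load tokenizer and model for a repo id; any quantization_config (e.g. FP8)
# is stripped so the weights load in plain float16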
def load_model(model_path):
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, trust_remote_code=True, token=HF_TOKEN
    )
    config = AutoConfig.from_pretrained(
        model_path, trust_remote_code=True, token=HF_TOKEN
    )
    if hasattr(config, "quantization_config"):
        # Delete the quantization config to avoid loading FP8 weights
        del config.quantization_config
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        config=config,
        trust_remote_code=True,
        token=HF_TOKEN,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return model, tokenizer
# Initialize the default model
current_model_name = "DeepSeek-R1-Distill-Llama-8B"
current_model, current_tokenizer = load_model(MODEL_NAMES[current_model_name])
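
# Chat handler for gr.ChatInterface: hot-swap the model if the dropdown
# selection changed, then generate a reply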
def chat(message, history, model_name):
    global current_model, current_tokenizer, current_model_name
    # Swap models only when the dropdown selection actually changed
    if model_name != current_model_name:
        current_model_name = model_name
        current_model, current_tokenizer = load_model(MODEL_NAMES[model_name])
    # Build the prompt from the whole conversation via the model's chat template
    messages = history + [{"role": "user", "content": message}]
    prompt = current_tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
    # max_new_tokens caps the reply length; max_length would also count the prompt
    outputs = current_model.generate(**inputs, max_new_tokens=1024)
    # Decode only the newly generated tokens, not the echoed prompt
    response = current_tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1] :], skip_special_tokens=True
    )
    return response
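
# Assemble the UI: a model dropdown feeding a ChatInterface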
with gr.Blocks() as app:
    gr.Markdown("## Chatbot with DeepSeek Models")
    model_selector = gr.Dropdown(
        choices=list(MODEL_NAMES.keys()),
        value=current_model_name,
        label="Select Model",
    )
    # Wire the dropdown in via additional_inputs so chat() receives the live
    # selection; reading model_selector.value in a lambda only captures the
    # initial value and never reflects later changes, so the no-op .change()
    # handler is no longer needed.
    chat_interface = gr.ChatInterface(
        fn=chat,
        additional_inputs=[model_selector],
        type="messages",
        flagging_mode="manual",
        save_history=True,
    )

app.launch()
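
# Usage: set HF_TOKEN in the environment if the repos are gated, then run
# `python app.py`. Note: flagging_mode and save_history on ChatInterface
# assume a recent Gradio (5.x); on older versions drop those two arguments.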