import gradio as gr
import spaces
import os
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# Read the Hugging Face token from the environment (huggingface_hub picks it up for gated model downloads)
HF_TOKEN = os.environ.get("HF_TOKEN", None)

DESCRIPTION = '''
<div>
<h1 style="text-align: center;">TAIDE/Llama3-TAIDE-LX-8B-Chat-Alpha1</h1>
<p>This Space demonstrates the instruction-tuned model <a href="https://huggingface.co/taide/Llama3-TAIDE-LX-8B-Chat-Alpha1"><b>Llama3-TAIDE-LX-8B-Chat-Alpha1</b></a>, an open LLM from TAIDE available in one size: 8B. Feel free to play with it, or duplicate this Space to run it privately!</p>
</div>
'''

LICENSE = """
<p/>
---
Built with TAIDE-LX-8B-Chat
"""

css = """
h1 {
  text-align: center;
  display: block;
}
#duplicate-button {
  margin: auto;
  color: white;
  background: #1565c0;
  border-radius: 100vh;
}
"""

# Load the tokenizer and model (bfloat16 + device_map="auto" are assumptions
# for a GPU Space; the default would load float32 weights on CPU)
tokenizer = AutoTokenizer.from_pretrained("taide/Llama3-TAIDE-LX-8B-Chat-Alpha1")
model = AutoModelForCausalLM.from_pretrained(
    "taide/Llama3-TAIDE-LX-8B-Chat-Alpha1", torch_dtype=torch.bfloat16, device_map="auto"
)

# Set pad_token_id (key fix): fall back to the EOS token when no pad token is defined
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
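
# Llama 3 ends each assistant turn with <|eot_id|> rather than the plain EOS
# token, so both are treated as stop tokens during generation.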

@spaces.GPU
def chat_taide_8b(message: str,
                  history: list,
                  temperature: float,
                  max_new_tokens: int):
    """
    Stream a response from Llama3-TAIDE-LX-8B-Chat-Alpha1,
    yielding the accumulated text as new tokens arrive.
    """
    try:
        conversation = []
        for user, assistant in history:
            conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
        conversation.append({"role": "user", "content": message})

        # Use return_dict=True so the tokenizer also returns the attention_mask (key fix)
        inputs = tokenizer.apply_chat_template(
            conversation, 
            return_tensors="pt", 
            return_dict=True,
            add_generation_prompt=True
        )
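        # apply_chat_template renders the turns into Llama 3's prompt format, roughly:
        #   <|start_header_id|>user<|end_header_id|>\n\n{message}<|eot_id|>
        #   <|start_header_id|>assistant<|end_header_id|>\n\n
        # and, with return_dict=True, returns input_ids plus attention_mask.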
        
        input_ids = inputs["input_ids"].to(model.device)
        attention_mask = inputs.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(model.device)
        
        streamer = TextIteratorStreamer(tokenizer, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
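        # TextIteratorStreamer exposes generated text as a blocking iterator;
        # skip_prompt=True drops the echoed input, and the 30s timeout keeps
        # the consuming loop from hanging if generation stalls.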

        generate_kwargs = dict(
            input_ids=input_ids,
            attention_mask=attention_mask,  # pass the attention_mask explicitly
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            eos_token_id=terminators,
            pad_token_id=tokenizer.pad_token_id,  # set pad_token_id explicitly
        )
        
        # Force greedy decoding (do_sample=False) when temperature is 0,
        # since sampling with temperature=0 would crash.
        if temperature == 0:
            generate_kwargs['do_sample'] = False
            
        t = Thread(target=model.generate, kwargs=generate_kwargs)
        t.start()
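        # model.generate runs in a background thread so this generator can
        # consume the streamer and yield partial text to the Gradio UI.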

        outputs = []
        for text in streamer:
            outputs.append(text)
            yield "".join(outputs)
            
    except Exception as e:
        yield f"生成過程中發生錯誤: {str(e)}"
    finally:
        # Free GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Gradio block
chatbot = gr.Chatbot(height=450, label='Gradio ChatInterface')
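
# gr.ChatInterface passes history as a list of (user, assistant) tuples,
# which chat_taide_8b converts into the chat-template message format.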

with gr.Blocks(fill_height=True, css=css) as demo:
    
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    gr.ChatInterface(
        fn=chat_taide_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(minimum=0,
                      maximum=1, 
                      step=0.1,
                      value=0.95, 
                      label="Temperature", 
                      render=False),
            gr.Slider(minimum=128, 
                      maximum=4096,
                      step=1,
                      value=512, 
                      label="Max new tokens", 
                      render=False),
        ],
        examples=[
            ['請以以下內容為基礎,寫一篇文章:撰寫一篇作文,題目為《一張舊照片》,內容要求為:選擇一張令你印象深刻的照片,說明令你印象深刻的原因,並描述照片中的影像及背後的故事。記錄成長的過程、與他人的情景、環境變遷和美麗的景色。'],
            ['請以品牌經理的身份,給廣告公司的創意總監寫一封信,提出對於新產品廣告宣傳活動的創意建議。'],
            ['以下提供英文內容,請幫我翻譯成中文。Dongshan coffee is famous for its unique position, and the constant refinement of production methods. The flavor is admired by many caffeine afficionados.'],
        ],
        cache_examples=False,
    )
    
    gr.Markdown(LICENSE)
    
if __name__ == "__main__":
    demo.launch()