wambugu71 committed
Commit 8e2adee · verified · 1 Parent(s): f6682e7

Create app.py

Files changed (1)
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
+ import gradio as gr
+ import os
+ import spaces
+ from transformers import AutoModelForCausalLM
+ from transformers import AutoTokenizer, TextIteratorStreamer
+ from threading import Thread
+
+ # Read the Hugging Face access token from the environment
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+
+ DESCRIPTION = '''
+ <div>
+   <h1 style="text-align: center;">unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit</h1>
+ </div>
+ '''
+
+ LICENSE = """
+ <p/>
+ ---
+ """
+
+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">DeepSeek-R1-Distill-Qwen-32B-bnb-4bit</h1>
+   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
+ </div>
+ """
+
+
+ css = """
+ h1 {
+   text-align: center;
+   display: block;
+ }
+ #duplicate-button {
+   margin: auto;
+   color: white;
+   background: #1565c0;
+   border-radius: 100vh;
+ }
+ """
+
+ # Load the tokenizer and model, and apply DeepSeek-R1's chat template
+ tokenizer = AutoTokenizer.from_pretrained("unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit")
+ tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}"
+
+ model = AutoModelForCausalLM.from_pretrained("unsloth/DeepSeek-R1-Distill-Qwen-32B-bnb-4bit", device_map="auto")  # to("cuda:0")
+ terminators = [
+     tokenizer.eos_token_id,
+ ]
+
+ @spaces.GPU(duration=120)
+ def chat_llama3_8b(message: str,
+                    history: list,
+                    temperature: float,
+                    max_new_tokens: int
+                    ) -> str:
+     """
+     Generate a streaming response using the DeepSeek-R1-Distill-Qwen-32B-bnb-4bit model.
+     Args:
+         message (str): The input message.
+         history (list): The conversation history used by ChatInterface.
+         temperature (float): The temperature for generating the response.
+         max_new_tokens (int): The maximum number of new tokens to generate.
+     Yields:
+         str: The response generated so far.
+     """
+     conversation = []
+     for user, assistant in history:
+         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt", add_generation_prompt=True).to(model.device)
+     # For debugging: print the rendered prompt and the token ids
+     print(tokenizer.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False))
+     print(input_ids)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+
+     generate_kwargs = dict(
+         input_ids=input_ids,
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         temperature=temperature,
+         eos_token_id=terminators,
+     )
+     # Force greedy generation (do_sample=False) when temperature is 0 to avoid a sampling crash.
+     if temperature == 0:
+         generate_kwargs['do_sample'] = False
+
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     outputs = []
+     for text in streamer:
+         # Replace thinking tags to prevent Gradio display issues
+         if "<think>" in text:
+             text = text.replace("<think>", "[think]").strip()
+         if "</think>" in text:
+             text = text.replace("</think>", "[/think]").strip()
+         outputs.append(text)
+         print("".join(outputs))
+         yield "".join(outputs)
+
+ # Gradio block
+ chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label='Gradio ChatInterface')
+
+ with gr.Blocks(fill_height=True, css=css) as demo:
+
+     gr.Markdown(DESCRIPTION)
+     gr.ChatInterface(
+         fn=chat_llama3_8b,
+         chatbot=chatbot,
+         fill_height=True,
+         additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
+         additional_inputs=[
+             gr.Slider(minimum=0,
+                       maximum=1,
+                       step=0.1,
+                       value=0.5,
+                       label="Temperature",
+                       render=False),
+             gr.Slider(minimum=128,
+                       maximum=4096,
+                       step=1,
+                       value=1024,
+                       label="Max new tokens",
+                       render=False),
+         ],
+         examples=[
+             ['How to setup a human base on Mars? Give short answer.'],
+             ['Explain theory of relativity to me like I’m 8 years old.'],
+             ['What is 9,000 * 9,000?'],
+             ['Write a pun-filled happy birthday message to my friend Alex.'],
+             ['Justify why a penguin might make a good king of the jungle.']
+         ],
+         cache_examples=False,
+     )
+
+     gr.Markdown(LICENSE)
+
+ if __name__ == "__main__":
+     demo.launch()