Update app.py
app.py
CHANGED
@@ -28,6 +28,7 @@ print("This might take a few minutes, especially on the first launch...")
 model = None
 tokenizer = None
 load_successful = False
+stop_token_ids_list = []  # Initialize stop_token_ids_list
 
 try:
     start_load_time = time.time()
@@ -35,11 +36,11 @@ try:
         MODEL_ID,
         torch_dtype=torch.float32,
         device_map="cpu",
-        # force_download=True #
+        # force_download=True # Keep commented unless cache issues reappear
     )
     tokenizer = AutoTokenizer.from_pretrained(
         MODEL_ID,
-        # force_download=True #
+        # force_download=True # Keep commented
     )
     model.eval()
     load_time = time.time() - start_load_time
@@ -48,14 +49,14 @@ try:
 
     # --- Stop Token Configuration ---
     stop_token_strings = ["<|endofturn|>", "<|stop|>"]
-
+    temp_stop_ids = [tokenizer.convert_tokens_to_ids(token) for token in stop_token_strings]
 
-    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in
-
+    if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in temp_stop_ids:
+        temp_stop_ids.append(tokenizer.eos_token_id)
     elif tokenizer.eos_token_id is None:
         print("Warning: tokenizer.eos_token_id is None. Cannot add to stop tokens.")
 
-    stop_token_ids_list = [tid for tid in
+    stop_token_ids_list = [tid for tid in temp_stop_ids if tid is not None]  # Assign to the global scope variable
 
     if not stop_token_ids_list:
         print("Warning: Could not find any stop token IDs. Using default EOS if available, otherwise generation might not stop correctly.")
@@ -63,7 +64,7 @@ try:
             stop_token_ids_list = [tokenizer.eos_token_id]
         else:
             print("Error: No stop tokens found, including default EOS. Generation may run indefinitely.")
-            #
+            # Consider raising an error or setting a default if this is critical
 
     print(f"Using Stop Token IDs: {stop_token_ids_list}")
 
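The two hunks above rework the stop-token setup: the stop strings are first resolved into token IDs, the EOS id is appended if missing, and anything the tokenizer could not resolve is dropped before the list is used anywhere. A minimal standalone sketch of that resolution follows; the repository id is an assumption standing in for the Space's MODEL_ID, and the final filter guards against `convert_tokens_to_ids` returning None (or an unknown-token id) for strings missing from the vocabulary.

# Illustrative sketch only; the repo id below is an assumed stand-in for MODEL_ID.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("naver-hyperclovax/HyperCLOVAX-SEED-Text-Instruct-0.5B")

stop_token_strings = ["<|endofturn|>", "<|stop|>"]
temp_stop_ids = [tokenizer.convert_tokens_to_ids(t) for t in stop_token_strings]

if tokenizer.eos_token_id is not None and tokenizer.eos_token_id not in temp_stop_ids:
    temp_stop_ids.append(tokenizer.eos_token_id)

# Drop anything the tokenizer could not resolve so generate() never sees a None id.
stop_token_ids_list = [tid for tid in temp_stop_ids if tid is not None]
print("Using Stop Token IDs:", stop_token_ids_list)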
@@ -72,7 +73,7 @@ except Exception as e:
     if 'model' in locals() and model is not None: del model
     if 'tokenizer' in locals() and tokenizer is not None: del tokenizer
     gc.collect()
-    #
+    # Raise Gradio error to display in the Space UI if loading fails
     raise gr.Error(f"Failed to load the model {MODEL_ID}. Cannot start the application. Error: {e}")
 
 
@@ -81,7 +82,7 @@ def get_system_prompt():
     current_date = datetime.datetime.now().strftime("%Y-%m-%d (%A)")
     return (
         f"- The name of this AI language model is \"CLOVA X\" and it was created by Naver.\n"
-        # f"- Today is {current_date}.\n" #
+        # f"- Today is {current_date}.\n" # Uncomment if needed
         f"- Answer the user's questions kindly, in detail, and in Korean."
     )
 
@@ -109,16 +110,22 @@ def warmup_model():
             return_tensors="pt"
         ).to("cpu")
 
+        # Check if stop_token_ids_list is empty and handle appropriately
+        gen_kwargs = {
+            "max_new_tokens": 10,
+            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
+            "do_sample": False
+        }
+        if stop_token_ids_list:
+            gen_kwargs["eos_token_id"] = stop_token_ids_list
+        else:
+            print("Warmup Warning: No stop tokens defined for generation.")
+
+
         with torch.no_grad():
-            output_ids = model.generate(
-
-
-                eos_token_id=stop_token_ids_list,
-                pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
-                do_sample=False # No sampling needed during warm-up
-            )
-
-            # Decode result (optional, for verification)
+            output_ids = model.generate(**inputs, **gen_kwargs)
+
+        # Optional: Decode warmup response for verification
         # response = tokenizer.decode(output_ids[0, inputs['input_ids'].shape[1]:], skip_special_tokens=True)
         # print(f"Warm-up response (decoded): {response}")
 
@@ -130,40 +137,43 @@ def warmup_model():
 
     except Exception as e:
         print(f"!!! Error during model warm-up: {e}")
-        # Make sure a warm-up failure does not block the app from launching
     finally:
-        gc.collect()
-
+        gc.collect()
 
 # --- Inference Function ---
 def predict(message, history):
     """
-    Generates response using HyperCLOVAX
-
-    Assumes 'history' is in the Gradio 'messages' format: List[List[str | None | Tuple]] or List[Dict]
+    Generates response using HyperCLOVAX.
+    Assumes 'history' is in the Gradio 'messages' format: List[Dict].
     """
     if model is None or tokenizer is None:
         return "Error: The model is not loaded."
 
     system_prompt = get_system_prompt()
 
-    #
+    # Start with system prompt
     chat_history_formatted = [
-        {"role": "tool_list", "content": ""},
+        {"role": "tool_list", "content": ""}, # As required by model card
         {"role": "system", "content": system_prompt}
     ]
-
-
-
-
-
-
-
-
-
-
-
-
+
+    # Append history (List of {'role': 'user'/'assistant', 'content': '...'})
+    if isinstance(history, list): # Check if history is a list
+        for turn in history:
+            # Validate turn format
+            if isinstance(turn, dict) and "role" in turn and "content" in turn:
+                chat_history_formatted.append(turn)
+            # Handle potential older tuple format if necessary (though less likely now)
+            elif isinstance(turn, (list, tuple)) and len(turn) == 2:
+                print(f"Warning: Received history item in tuple format: {turn}. Converting to messages format.")
+                chat_history_formatted.append({"role": "user", "content": turn[0]})
+                if turn[1]: # Ensure assistant message exists
+                    chat_history_formatted.append({"role": "assistant", "content": turn[1]})
+            else:
+                print(f"Warning: Skipping unexpected history format item: {turn}")
+
+
+    # Append the latest user message
     chat_history_formatted.append({"role": "user", "content": message})
 
     inputs = None
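The block added to predict() above normalizes whatever shape of history Gradio hands over: "messages"-style dicts (the format this Space now uses) or the older (user, assistant) tuples, depending on the Gradio version. A self-contained sketch of that normalization with hypothetical sample data is below.

# Hypothetical data for illustration; mirrors the history-normalization loop added above.
def normalize_history(history):
    normalized = []
    for turn in history:
        if isinstance(turn, dict) and "role" in turn and "content" in turn:
            normalized.append(turn)  # already in messages format
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            normalized.append({"role": "user", "content": turn[0]})
            if turn[1]:  # assistant reply may be missing for the last turn
                normalized.append({"role": "assistant", "content": turn[1]})
    return normalized

print(normalize_history([("Hi", "Hello! How can I help?")]))
print(normalize_history([{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]))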
@@ -175,41 +185,47 @@ def predict(message, history):
             add_generation_prompt=True,
             return_dict=True,
             return_tensors="pt"
-        ).to("cpu")
+        ).to("cpu")
         input_length = inputs['input_ids'].shape[1]
         print(f"\nInput tokens: {input_length}")
 
     except Exception as e:
         print(f"!!! Error applying chat template: {e}")
-        # Provide feedback to the user
         return f"Error: A problem occurred while processing the input format. ({e})"
 
     try:
         print("Generating response...")
         generation_start_time = time.time()
+
+        # Prepare generation arguments, handling empty stop_token_ids_list
+        gen_kwargs = {
+            "max_new_tokens": MAX_NEW_TOKENS,
+            "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
+            "do_sample": True,
+            "temperature": 0.7,
+            "top_p": 0.9,
+        }
+        if stop_token_ids_list:
+            gen_kwargs["eos_token_id"] = stop_token_ids_list
+        else:
+            print("Generation Warning: No stop tokens defined.")
+
+
         with torch.no_grad():
-            output_ids = model.generate(
-
-                max_new_tokens=MAX_NEW_TOKENS,
-                eos_token_id=stop_token_ids_list,
-                pad_token_id=tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
-                do_sample=True,
-                temperature=0.7,
-                top_p=0.9,
-            )
+            output_ids = model.generate(**inputs, **gen_kwargs)
+
         generation_time = time.time() - generation_start_time
         print(f"Generation complete in {generation_time:.2f} seconds.")
 
     except Exception as e:
         print(f"!!! Error during model generation: {e}")
-        # Clean up potentially large tensors in case of error
         if inputs is not None: del inputs
         if output_ids is not None: del output_ids
         gc.collect()
         return f"Error: A problem occurred while generating the response. ({e})"
 
     # Decode the response
-    response = "Error: Failed to generate a response."
+    response = "Error: Failed to generate a response."
     if output_ids is not None:
         try:
             new_tokens = output_ids[0, input_length:]
@@ -220,7 +236,6 @@ def predict(message, history):
             print(f"!!! Error decoding response: {e}")
             response = "Error: A problem occurred while decoding the response."
 
-
     # Clean up memory
     if inputs is not None: del inputs
     if output_ids is not None: del output_ids
@@ -232,13 +247,8 @@ def predict(message, history):
 # --- Gradio Interface Setup ---
 print("--- Setting up Gradio Interface ---")
 
-#
-chatbot_component = gr.Chatbot(
-    label="HyperCLOVA X SEED (0.5B) Chat",
-    bubble_full_width=False,
-    height=600,
-    type='messages' # Specify this explicitly to ensure compatibility with ChatInterface
-)
+# No need to create a separate Chatbot component beforehand
+# chatbot_component = gr.Chatbot(...) # REMOVED
 
 examples = [
     ["What is Naver CLOVA X?"],
@@ -247,34 +257,32 @@ examples = [
     ["I'm planning a trip to Jeju Island. Could you put together a recommended 3-night, 4-day itinerary?"],
 ]
 
-# ChatInterface
+# Let ChatInterface manage its own internal Chatbot component
+# Remove the chatbot=... argument
 demo = gr.ChatInterface(
-    fn=predict, #
-    chatbot=chatbot_component, #
+    fn=predict, # Link the prediction function
+    # chatbot=chatbot_component, # REMOVED
     title="🇰🇷 Naver HyperCLOVA X SEED (0.5B) Demo",
     description=(
         f"**Model:** {MODEL_ID}\n"
         f"**Environment:** Hugging Face free CPU (16GB RAM)\n"
-        f"**Caution:** Runs on CPU, so generating a response may take some time. (Warm-up
+        f"**Caution:** Runs on CPU, so generating a response may take some time. (Warm-up complete)\n"
         f"The maximum number of generated tokens is limited to {MAX_NEW_TOKENS}."
     ),
     examples=examples,
-    cache_examples=False,
+    cache_examples=False,
     theme="soft",
-    # retry_btn, undo_btn, clear_btn, etc. are not directly supported in recent versions
 )
 
 # --- Application Launch ---
 if __name__ == "__main__":
-    # Run warm-up only if the model loaded successfully
     if load_successful:
         warmup_model()
     else:
         print("Skipping warm-up because model loading failed.")
 
     print("--- Launching Gradio App ---")
-    # queue() is useful for handling multiple users and managing long-running tasks
     demo.queue().launch(
-        # share=True #
-        # server_name="0.0.0.0" #
+        # share=True # Uncomment for public link
+        # server_name="0.0.0.0" # Uncomment for local network access
     )
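Both generate() calls now go through a gen_kwargs dict so that eos_token_id is only supplied when some stop-token IDs were actually resolved; recent transformers versions accept either a single id or a list of ids for that argument, which is why the resolved list can be passed through directly. A condensed sketch of that pattern follows; model, tokenizer, inputs, and stop_token_ids_list are assumed to already exist as in app.py, and 512 is only a stand-in for MAX_NEW_TOKENS.

# Condensed sketch of the new generation path; names are assumed to be defined as in app.py.
import torch

gen_kwargs = {
    "max_new_tokens": 512,  # stand-in for MAX_NEW_TOKENS
    "pad_token_id": tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.pad_token_id,
    "do_sample": True,
    "temperature": 0.7,
    "top_p": 0.9,
}
if stop_token_ids_list:  # only pass stop ids that were actually resolved
    gen_kwargs["eos_token_id"] = stop_token_ids_list

with torch.no_grad():
    output_ids = model.generate(**inputs, **gen_kwargs)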