Commit · cdba9e2
1 Parent(s): ddb47d4
Implement streaming inference display
- Added TextIteratorStreamer for real-time token streaming (pattern sketched below)
- Response appears word-by-word as it's generated
- Added a typing indicator (▌) during generation
- Improved user experience with immediate visual feedback
- Used threading for non-blocking generation
- Added CSS animation for the typing indicator
- Shows '🚀 Starting...' initially, then streams the response
- Button shows '⏳ Generating...' during streaming
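The change follows the usual Transformers streaming recipe: model.generate() blocks until the whole sequence is finished, so it is launched on a worker thread with a TextIteratorStreamer attached, and the caller iterates over the streamer to receive decoded text as it is produced. Below is a minimal, self-contained sketch of that recipe; the tiny checkpoint, prompt, and generation settings are placeholders for illustration, not the Space's actual configuration.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder checkpoint purely for illustration; the Space loads its own model.
model_name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Streaming lets the UI", return_tensors="pt")

# skip_prompt drops the echoed input; skip_special_tokens strips EOS/PAD markers.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks until the full sequence is done, so it runs on a worker thread
# while the main thread consumes decoded text chunks from the streamer.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=20, do_sample=True, streamer=streamer),
)
thread.start()

partial = ""
for chunk in streamer:      # yields decoded text as soon as tokens are produced
    partial += chunk
    print(partial)          # a UI callback would re-render with the growing text

thread.join()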
app.py
CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from threading import Thread
 import spaces
 
 # Model configuration
@@ -67,7 +68,7 @@ def format_prompt(message, history):
     return prompt
 
 @spaces.GPU(duration=60)
-def generate_response(
+def generate_response_streaming(
     message,
     history,
     temperature=0.3,
@@ -75,7 +76,7 @@ def generate_response(
     top_p=0.95,
     repetition_penalty=1.05,
 ):
-    """Generate response from the model"""
+    """Generate response from the model with streaming"""
 
     # Format the prompt
     prompt = format_prompt(message, history)
@@ -89,23 +90,38 @@ def generate_response(
         if k != 'token_type_ids':  # Filter out token_type_ids
             model_inputs[k] = v.to(model.device)
 
+    # Set up the streamer
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        skip_prompt=True,
+        skip_special_tokens=True,
+        timeout=30.0
+    )
+
+    # Generation parameters
+    generation_kwargs = dict(
+        **model_inputs,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        streamer=streamer,
+    )
+
+    # Start generation in a separate thread
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
 
+    # Stream the response
+    partial_response = ""
+    for new_text in streamer:
+        partial_response += new_text
+        yield partial_response
 
+    thread.join()
 
 # Question categories for the carousel
 QUESTION_CATEGORIES = {
@@ -415,6 +431,17 @@ custom_css = """
     animation: pulse 1.5s ease-in-out infinite;
 }
 
+/* Typing indicator */
+@keyframes typing {
+    0%, 60%, 100% { opacity: 0.3; }
+    30% { opacity: 1; }
+}
+
+.typing-indicator {
+    display: inline-block;
+    animation: typing 1.4s infinite;
+}
+
 .question-button:last-child {
     margin-bottom: 0;
 }
@@ -852,20 +879,26 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
             return "", chat_history, gr.update(value="Send Question")
 
         try:
-            # Show processing state
+            # Show initial processing state
+            yield "", chat_history + [(message, "🚀 Starting...")], gr.update(value="⏳ Generating...")
 
+            # Stream the response
+            partial_response = ""
+            for chunk in generate_response_streaming(
                 message,
                 chat_history,
                 temperature,
                 max_new_tokens,
                 top_p,
                 repetition_penalty
-            )
+            ):
+                partial_response = chunk
+                # Update chat with partial response and typing indicator
+                current_history = chat_history + [(message, partial_response + " ▌")]
+                yield "", current_history, gr.update(value="⏳ Generating...")
 
+            # Final update with complete response
+            chat_history.append((message, partial_response))
             yield "", chat_history, gr.update(value="Send Question")
 
         except Exception as e:
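Note that the hunks above only change the handler body; for the incremental yields to actually stream into the chat UI, the handler has to be registered as a generator callback on the click/submit event with queuing enabled. The sketch below shows that wiring with made-up component names and a stub handler that fakes generation word by word; none of the identifiers come from the Space, it only mirrors the yield-three-outputs shape used in the diff.

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label="Question")
    send_btn = gr.Button("Send Question")
    temperature = gr.Slider(0.0, 1.0, value=0.3, label="Temperature")
    max_new_tokens = gr.Slider(16, 1024, value=512, step=16, label="Max new tokens")
    top_p = gr.Slider(0.0, 1.0, value=0.95, label="Top-p")
    repetition_penalty = gr.Slider(1.0, 2.0, value=1.05, label="Repetition penalty")

    def respond(message, chat_history, temperature, max_new_tokens, top_p, repetition_penalty):
        # Stub in place of the real model call: stream a canned reply word by word,
        # yielding (textbox value, updated history, button update) on every step.
        partial = ""
        for word in ("This", "is", "a", "streamed", "reply."):
            partial += word + " "
            yield "", chat_history + [(message, partial + "▌")], gr.update(value="⏳ Generating...")
        yield "", chat_history + [(message, partial.strip())], gr.update(value="Send Question")

    # A generator callback streams each yielded value to its outputs; queuing is
    # required for streaming updates (enabled by default in recent Gradio).
    send_btn.click(
        respond,
        inputs=[msg, chatbot, temperature, max_new_tokens, top_p, repetition_penalty],
        outputs=[msg, chatbot, send_btn],
    )

demo.queue().launch()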