chrisvoncsefalvay committed on
Commit cdba9e2 · 1 Parent(s): ddb47d4

Implement streaming inference display

- Added TextIteratorStreamer for real-time token streaming (a minimal sketch of the pattern follows below)
- Response appears word-by-word as it's generated
- Added typing indicator (●) during generation
- Improved user experience with immediate visual feedback
- Used threading for non-blocking generation
- Added CSS animation for typing indicator
- Shows '🔄 Starting...' initially, then streams response
- Button shows '⏳ Generating...' during streaming
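
The change boils down to the standard transformers streaming pattern: run model.generate() on a worker thread, hand it a TextIteratorStreamer, and iterate the streamer from a generator that Gradio consumes. A minimal, self-contained sketch of that pattern follows; the model ID and the gr.ChatInterface wiring are illustrative assumptions for a standalone example, not code from app.py.

    # Sketch only: placeholder model and simplified wiring, not the app.py implementation.
    from threading import Thread

    import gradio as gr
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    MODEL_ID = "gpt2"  # placeholder; the Space loads its own model
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

    def stream_reply(message, history):
        inputs = tokenizer(message, return_tensors="pt").to(model.device)
        # skip_prompt=True keeps the echoed prompt out of the streamed text
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        kwargs = dict(**inputs, max_new_tokens=256, do_sample=True, streamer=streamer)
        # generate() blocks, so it runs on a worker thread while we consume the streamer here
        Thread(target=model.generate, kwargs=kwargs).start()
        partial = ""
        for token_text in streamer:
            partial += token_text
            yield partial  # each yield re-renders the chat bubble with the text so far

    gr.ChatInterface(stream_reply).launch()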

Files changed (1)
  1. app.py +57 -24
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from threading import Thread
 import spaces

 # Model configuration
@@ -67,7 +68,7 @@ def format_prompt(message, history):
     return prompt

 @spaces.GPU(duration=60)
-def generate_response(
+def generate_response_streaming(
     message,
     history,
     temperature=0.3,
@@ -75,7 +76,7 @@ def generate_response(
     top_p=0.95,
     repetition_penalty=1.05,
 ):
-    """Generate response from the model"""
+    """Generate response from the model with streaming"""

     # Format the prompt
     prompt = format_prompt(message, history)
@@ -89,23 +90,38 @@ def generate_response(
         if k != 'token_type_ids':  # Filter out token_type_ids
             model_inputs[k] = v.to(model.device)

-    # Generate response
-    with torch.no_grad():
-        outputs = model.generate(
-            **model_inputs,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            do_sample=True,
-            pad_token_id=tokenizer.pad_token_id,
-            eos_token_id=tokenizer.eos_token_id,
-        )
+    # Set up the streamer
+    streamer = TextIteratorStreamer(
+        tokenizer,
+        skip_prompt=True,
+        skip_special_tokens=True,
+        timeout=30.0
+    )
+
+    # Generation parameters
+    generation_kwargs = dict(
+        **model_inputs,
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        pad_token_id=tokenizer.pad_token_id,
+        eos_token_id=tokenizer.eos_token_id,
+        streamer=streamer,
+    )
+
+    # Start generation in a separate thread
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()

-    # Decode response
-    response = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
+    # Stream the response
+    partial_response = ""
+    for new_text in streamer:
+        partial_response += new_text
+        yield partial_response

-    return response
+    thread.join()

 # Question categories for the carousel
 QUESTION_CATEGORIES = {
@@ -415,6 +431,17 @@ custom_css = """
     animation: pulse 1.5s ease-in-out infinite;
 }

+/* Typing indicator */
+@keyframes typing {
+    0%, 60%, 100% { opacity: 0.3; }
+    30% { opacity: 1; }
+}
+
+.typing-indicator {
+    display: inline-block;
+    animation: typing 1.4s infinite;
+}
+
 .question-button:last-child {
     margin-bottom: 0;
 }
@@ -852,20 +879,26 @@ with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:
             return "", chat_history, gr.update(value="Send Question")

         try:
-            # Show processing state
-            yield "", chat_history + [(message, "🔄 Processing...")], gr.update(value="⏳ Generating...")
+            # Show initial processing state
+            yield "", chat_history + [(message, "🔄 Starting...")], gr.update(value="⏳ Generating...")

-            response = generate_response(
+            # Stream the response
+            partial_response = ""
+            for chunk in generate_response_streaming(
                 message,
                 chat_history,
                 temperature,
                 max_new_tokens,
                 top_p,
                 repetition_penalty
-            )
+            ):
+                partial_response = chunk
+                # Update chat with partial response and typing indicator
+                current_history = chat_history + [(message, partial_response + " ●")]
+                yield "", current_history, gr.update(value="⏳ Generating...")

-            # Update with actual response
-            chat_history.append((message, response))
+            # Final update with complete response
+            chat_history.append((message, partial_response))
             yield "", chat_history, gr.update(value="Send Question")

         except Exception as e:
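
Not shown in this hunk is how the streaming respond() handler reaches the UI: because it is a generator, each yield pushes fresh values to its output components, so the Chatbot bubble grows word by word and the send button label flips between "⏳ Generating..." and "Send Question". A hedged sketch of that kind of wiring is below; all component and handler names are assumptions (and the button update is omitted for brevity), not code lifted from app.py.

    # Sketch only: hypothetical event wiring for a generator handler in gr.Blocks.
    import gradio as gr

    def respond(message, chat_history):
        # Stand-in for the streaming handler: each yield updates the outputs in place.
        partial = ""
        for word in ("streamed", "tokens", "appear", "here"):
            partial += word + " "
            yield "", chat_history + [(message, partial + "●")]
        yield "", chat_history + [(message, partial.strip())]

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        msg = gr.Textbox(placeholder="Ask a question")
        send = gr.Button("Send Question")
        # Generator handlers stream through the queue; every yield re-renders the outputs.
        send.click(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])
        msg.submit(respond, inputs=[msg, chatbot], outputs=[msg, chatbot])

    demo.queue().launch()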