AC2513 committed on
Commit
fd3c6d5
·
1 Parent(s): c89883c

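Added additional adjustments (temperature, top-p, top-k and repetition-penalty sampling controls in app.py) and consolidated the process_history tests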

Browse files
Files changed (2) hide show
  1. app.py +21 -0
  2. tests/test_media.py +53 -123
app.py CHANGED
@@ -130,6 +130,10 @@ def run(
130
  system_prompt: str,
131
  max_new_tokens: int,
132
  max_images: int,
 
 
 
 
133
  ) -> Iterator[str]:
134
 
135
  logger.debug(
@@ -162,6 +166,11 @@ def run(
162
  inputs,
163
  streamer=streamer,
164
  max_new_tokens=max_new_tokens,
 
 
 
 
 
165
  )
166
  t = Thread(target=model.generate, kwargs=generate_kwargs)
167
  t.start()
@@ -186,6 +195,18 @@ demo = gr.ChatInterface(
186
  label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
187
  ),
188
  gr.Slider(label="Max Images", minimum=1, maximum=4, step=1, value=2),
 
 
 
 
 
 
 
 
 
 
 
 
189
  ],
190
  stop_btn=False,
191
  )
 
130
  system_prompt: str,
131
  max_new_tokens: int,
132
  max_images: int,
133
+ temperature: float,
134
+ top_p: float,
135
+ top_k: int,
136
+ repetition_penalty: float,
137
  ) -> Iterator[str]:
138
 
139
  logger.debug(
 
166
  inputs,
167
  streamer=streamer,
168
  max_new_tokens=max_new_tokens,
169
+ temperature=temperature,
170
+ top_p=top_p,
171
+ top_k=top_k,
172
+ repetition_penalty=repetition_penalty,
173
+ do_sample=True,
174
  )
175
  t = Thread(target=model.generate, kwargs=generate_kwargs)
176
  t.start()
 
195
  label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700
196
  ),
197
  gr.Slider(label="Max Images", minimum=1, maximum=4, step=1, value=2),
198
+ gr.Slider(
199
+ label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7
200
+ ),
201
+ gr.Slider(
202
+ label="Top P", minimum=0.1, maximum=1.0, step=0.05, value=0.9
203
+ ),
204
+ gr.Slider(
205
+ label="Top K", minimum=1, maximum=100, step=1, value=50
206
+ ),
207
+ gr.Slider(
208
+ label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1
209
+ )
210
  ],
211
  stop_btn=False,
212
  )
tests/test_media.py CHANGED
@@ -5,7 +5,7 @@ from PIL import Image
5
  from pathlib import Path
6
  import tempfile
7
 
8
- from src.app import get_frames, process_video, process_user_input, process_history
9
 
10
  # Get the project root directory
11
  ROOT_DIR = Path(__file__).parent.parent
@@ -235,135 +235,65 @@ def test_process_user_input_max_images_effect():
235
  assert frames_many <= 5
236
  assert frames_few < frames_many
237
 
238
- def test_empty_history():
239
- """Test processing empty history."""
240
- history = []
241
- result = process_history(history)
242
- assert result == []
243
-
244
- def test_single_user_message_text():
245
- """Test processing a single user text message."""
246
- history = [
247
- {"role": "user", "content": "Hello, AI!"}
248
- ]
249
 
250
- result = process_history(history)
251
-
252
- assert len(result) == 1
253
- assert result[0]["role"] == "user"
254
- assert len(result[0]["content"]) == 1
255
- assert result[0]["content"][0]["type"] == "text"
256
- assert result[0]["content"][0]["text"] == "Hello, AI!"
257
-
258
- def test_single_user_message_image():
259
- """Test processing a single user image message."""
260
- history = [
261
- {"role": "user", "content": ["path/to/image.jpg"]}
262
- ]
263
-
264
- result = process_history(history)
265
-
266
- assert len(result) == 1
267
- assert result[0]["role"] == "user"
268
- assert len(result[0]["content"]) == 1
269
- assert result[0]["content"][0]["type"] == "image"
270
- assert result[0]["content"][0]["url"] == "path/to/image.jpg"
271
-
272
- def test_single_assistant_message():
273
- """Test processing a single assistant message."""
274
- history = [
275
- {"role": "assistant", "content": "I'm an AI assistant."}
276
- ]
277
-
278
- result = process_history(history)
279
-
280
- assert len(result) == 1
281
- assert result[0]["role"] == "assistant"
282
- assert len(result[0]["content"]) == 1
283
- assert result[0]["content"][0]["type"] == "text"
284
- assert result[0]["content"][0]["text"] == "I'm an AI assistant."
285
-
286
- def test_alternating_messages():
287
- """Test processing alternating user and assistant messages."""
288
  history = [
289
  {"role": "user", "content": "Hello"},
290
- {"role": "assistant", "content": "Hi there"},
291
- {"role": "user", "content": "How are you?"},
292
- {"role": "assistant", "content": "I'm doing well, thanks!"}
293
- ]
294
-
295
- result = process_history(history)
296
-
297
- assert len(result) == 4
298
- assert [item["role"] for item in result] == ["user", "assistant", "user", "assistant"]
299
- assert result[0]["content"][0]["text"] == "Hello"
300
- assert result[1]["content"][0]["text"] == "Hi there"
301
- assert result[2]["content"][0]["text"] == "How are you?"
302
- assert result[3]["content"][0]["text"] == "I'm doing well, thanks!"
303
-
304
- def test_consecutive_user_messages():
305
- """Test processing consecutive user messages - they should be grouped."""
306
- history = [
307
- {"role": "user", "content": "First message"},
308
- {"role": "user", "content": "Second message"},
309
- {"role": "user", "content": "Third message"},
310
- {"role": "assistant", "content": "I got your messages"}
311
  ]
312
 
313
  result = process_history(history)
314
-
315
- # Should be combined into a single user message with multiple content items
316
- assert len(result) == 2
317
- assert result[0]["role"] == "user"
318
- assert len(result[0]["content"]) == 3
319
- assert [item["text"] for item in result[0]["content"]] == [
320
- "First message", "Second message", "Third message"
321
- ]
322
-
323
- def test_mixed_content_types():
324
- """Test processing mixed content types (text and images)."""
325
- history = [
326
- {"role": "user", "content": "Look at this:"},
327
- {"role": "user", "content": ["image.jpg"]},
328
- {"role": "assistant", "content": "Nice image!"}
329
- ]
330
-
331
- result = process_history(history)
332
-
333
- assert len(result) == 2
334
- assert result[0]["role"] == "user"
335
- assert len(result[0]["content"]) == 2
336
- assert result[0]["content"][0]["type"] == "text"
337
- assert result[0]["content"][1]["type"] == "image"
338
-
339
- def test_ending_with_user_messages():
340
- """Test history that ends with user messages."""
341
- history = [
342
- {"role": "user", "content": "Hello"},
343
- {"role": "assistant", "content": "Hi there"},
344
- {"role": "user", "content": "Another question"},
345
- {"role": "user", "content": ["image.png"]}
346
- ]
347
-
348
- result = process_history(history)
349
-
350
  assert len(result) == 3
351
- assert result[2]["role"] == "user"
352
- assert len(result[2]["content"]) == 2
353
- assert result[2]["content"][0]["type"] == "text"
354
- assert result[2]["content"][1]["type"] == "image"
355
 
356
- def test_empty_messages():
357
- """Test handling of empty content messages."""
358
- history = [
359
- {"role": "user", "content": ""},
360
- {"role": "assistant", "content": ""},
361
- {"role": "user", "content": "Hello"}
362
- ]
363
 
364
- result = process_history(history)
365
 
366
- assert len(result) == 3
367
- assert result[0]["content"][0]["text"] == ""
368
- assert result[1]["content"][0]["text"] == ""
369
- assert result[2]["content"][0]["text"] == "Hello"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from pathlib import Path
6
  import tempfile
7
 
8
+ from app import get_frames, process_video, process_user_input, process_history
9
 
10
  # Get the project root directory
11
  ROOT_DIR = Path(__file__).parent.parent
 
235
  assert frames_many <= 5
236
  assert frames_few < frames_many
237
 
238
+ def test_process_history_basic_functionality():
239
+ """Test basic conversation processing and content buffering."""
240
+ # Empty history
241
+ assert process_history([]) == []
 
 
 
 
 
 
 
242
 
243
+ # Simple conversation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  history = [
245
  {"role": "user", "content": "Hello"},
246
+ {"role": "assistant", "content": "Hi there!"},
247
+ {"role": "user", "content": "How are you?"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  ]
249
 
250
  result = process_history(history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  assert len(result) == 3
252
+ assert result[0] == {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
253
+ assert result[1] == {"role": "assistant", "content": [{"type": "text", "text": "Hi there!"}]}
254
+ assert result[2] == {"role": "user", "content": [{"type": "text", "text": "How are you?"}]}
 
255
 
256
+
257
+ def test_process_history_file_handling():
258
+ """Test processing of different file types and content buffering."""
259
+ # Create temp image file
260
+ with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as img:
261
+ image_path = img.name
 
262
 
263
+ video_path = os.path.join(ROOT_DIR, "assets", "test_video.mp4")
264
 
265
+ try:
266
+ history = [
267
+ {"role": "user", "content": (image_path,)},
268
+ {"role": "user", "content": "What's this image?"},
269
+ {"role": "user", "content": (video_path,)},
270
+ {"role": "assistant", "content": "I see an image and video."},
271
+ {"role": "user", "content": "First"},
272
+ {"role": "user", "content": "Second"},
273
+ {"role": "user", "content": "Third"} # Multiple user messages at end
274
+ ]
275
+
276
+ result = process_history(history)
277
+ assert len(result) == 3
278
+
279
+ # First user turn: image + text + video
280
+ assert result[0]["role"] == "user"
281
+ assert len(result[0]["content"]) == 3
282
+ assert result[0]["content"][0] == {"type": "image", "url": image_path}
283
+ assert result[0]["content"][1] == {"type": "text", "text": "What's this image?"}
284
+ assert result[0]["content"][2] == {"type": "text", "text": "[Video uploaded previously]"}
285
+
286
+ # Assistant response
287
+ assert result[1]["role"] == "assistant"
288
+ assert result[1]["content"] == [{"type": "text", "text": "I see an image and video."}]
289
+
290
+ # Final user turn: multiple buffered messages
291
+ assert result[2]["role"] == "user"
292
+ assert len(result[2]["content"]) == 3
293
+ assert result[2]["content"][0] == {"type": "text", "text": "First"}
294
+ assert result[2]["content"][1] == {"type": "text", "text": "Second"}
295
+ assert result[2]["content"][2] == {"type": "text", "text": "Third"}
296
+
297
+ finally:
298
+ if os.path.exists(image_path):
299
+ os.unlink(image_path)