openfree committed on
Commit
1bf66d9
·
verified ·
1 Parent(s): dfd8114

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -462
app.py CHANGED
@@ -3,18 +3,16 @@
3
  import os
4
  import re
5
  import tempfile
6
- import gc # garbage collector 추가
7
  from collections.abc import Iterator
8
  from threading import Thread
9
  import json
10
  import requests
11
- import cv2
12
  import gradio as gr
13
  import spaces
14
  import torch
15
  from loguru import logger
16
- from PIL import Image
17
- from transformers import AutoProcessor, Gemma3ForConditionalGeneration, TextIteratorStreamer
18
 
19
  # CSV/TXT 분석
20
  import pandas as pd
@@ -51,7 +49,6 @@ def extract_keywords(text: str, top_k: int = 5) -> str:
51
 
52
  ##############################################################################
53
  # SerpHouse Live endpoint 호출
54
- # - 상위 20개 결과 JSON을 LLM에 넘길 때 link, snippet 등 모두 포함
55
  ##############################################################################
56
  def do_web_search(query: str) -> str:
57
  """
@@ -61,14 +58,13 @@ def do_web_search(query: str) -> str:
61
  try:
62
  url = "https://api.serphouse.com/serp/live"
63
 
64
- # 기본 GET 방식으로 파라미터 간소화하고 결과 수를 20개로 제한
65
  params = {
66
  "q": query,
67
  "domain": "google.com",
68
- "serp_type": "web", # 기본 웹 검색
69
  "device": "desktop",
70
  "lang": "en",
71
- "num": "20" # 최대 20개 결과만 요청
72
  }
73
 
74
  headers = {
@@ -76,44 +72,33 @@ def do_web_search(query: str) -> str:
76
  }
77
 
78
  logger.info(f"SerpHouse API 호출 중... 검색어: {query}")
79
- logger.info(f"요청 URL: {url} - 파라미터: {params}")
80
 
81
- # GET 요청 수행
82
  response = requests.get(url, headers=headers, params=params, timeout=60)
83
  response.raise_for_status()
84
 
85
- logger.info(f"SerpHouse API 응답 상태 코드: {response.status_code}")
86
  data = response.json()
87
 
88
  # 다양한 응답 구조 처리
89
  results = data.get("results", {})
90
  organic = None
91
 
92
- # 가능한 응답 구조 1
93
  if isinstance(results, dict) and "organic" in results:
94
  organic = results["organic"]
95
-
96
- # 가능한 응답 구조 2 (중첩된 results)
97
  elif isinstance(results, dict) and "results" in results:
98
  if isinstance(results["results"], dict) and "organic" in results["results"]:
99
  organic = results["results"]["organic"]
100
-
101
- # 가능한 응답 구조 3 (최상위 organic)
102
  elif "organic" in data:
103
  organic = data["organic"]
104
 
105
  if not organic:
106
  logger.warning("응답에서 organic 결과를 찾을 수 없습니다.")
107
- logger.debug(f"응답 구조: {list(data.keys())}")
108
- if isinstance(results, dict):
109
- logger.debug(f"results 구조: {list(results.keys())}")
110
  return "No web search results found or unexpected API response structure."
111
 
112
  # 결과 수 제한 및 컨텍스트 길이 최적화
113
  max_results = min(20, len(organic))
114
  limited_organic = organic[:max_results]
115
 
116
- # 결과 형식 개선 - 마크다운 형식으로 출력하여 가독성 향상
117
  summary_lines = []
118
  for idx, item in enumerate(limited_organic, start=1):
119
  title = item.get("title", "No title")
@@ -121,7 +106,6 @@ def do_web_search(query: str) -> str:
121
  snippet = item.get("snippet", "No description")
122
  displayed_link = item.get("displayed_link", link)
123
 
124
- # 마크다운 형식 (링크 클릭 가능)
125
  summary_lines.append(
126
  f"### Result {idx}: {title}\n\n"
127
  f"{snippet}\n\n"
@@ -129,7 +113,6 @@ def do_web_search(query: str) -> str:
129
  f"---\n"
130
  )
131
 
132
- # 모델에게 명확한 지침 추가
133
  instructions = """
134
  # 웹 검색 결과
135
  아래는 검색 결과입니다. 질문에 답변할 때 이 정보를 활용하세요:
@@ -147,31 +130,27 @@ def do_web_search(query: str) -> str:
147
  logger.error(f"Web search failed: {e}")
148
  return f"Web search failed: {str(e)}"
149
 
150
-
151
  ##############################################################################
152
- # 모델/프로세서 로딩
153
  ##############################################################################
154
  MAX_CONTENT_CHARS = 2000
155
- MAX_INPUT_LENGTH = 2096 # 최대 입력 토큰 수 제한 추가
156
  model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-1B")
157
 
158
- processor = AutoProcessor.from_pretrained(model_id, padding_side="left")
159
- model = Gemma3ForConditionalGeneration.from_pretrained(
 
160
  model_id,
161
  device_map="auto",
162
  torch_dtype=torch.bfloat16,
163
- attn_implementation="eager" # 가능하다면 "flash_attention_2"로 변경
164
  )
165
- MAX_NUM_IMAGES = int(os.getenv("MAX_NUM_IMAGES", "5"))
166
-
167
 
168
  ##############################################################################
169
  # CSV, TXT, PDF 분석 함수
170
  ##############################################################################
171
  def analyze_csv_file(path: str) -> str:
172
- """
173
- CSV 파일을 전체 문자열로 변환. 너무 길 경우 일부만 표시.
174
- """
175
  try:
176
  df = pd.read_csv(path)
177
  if df.shape[0] > 50 or df.shape[1] > 10:
@@ -183,11 +162,8 @@ def analyze_csv_file(path: str) -> str:
183
  except Exception as e:
184
  return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
185
 
186
-
187
  def analyze_txt_file(path: str) -> str:
188
- """
189
- TXT 파일 전문 읽기. 너무 길면 일부만 표시.
190
- """
191
  try:
192
  with open(path, "r", encoding="utf-8") as f:
193
  text = f.read()
@@ -197,11 +173,8 @@ def analyze_txt_file(path: str) -> str:
197
  except Exception as e:
198
  return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
199
 
200
-
201
  def pdf_to_markdown(pdf_path: str) -> str:
202
- """
203
- PDF 텍스트를 Markdown으로 변환. 페이지별로 간단히 텍스트 추출.
204
- """
205
  text_chunks = []
206
  try:
207
  with open(pdf_path, "rb") as f:
@@ -226,146 +199,9 @@ def pdf_to_markdown(pdf_path: str) -> str:
226
 
227
  return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
228
 
229
-
230
- ##############################################################################
231
- # 이미지/비디오 업로드 제한 검사
232
- ##############################################################################
233
def count_files_in_new_message(paths: list[str]) -> tuple[int, int]:
    """Count the image and video files among ``paths``.

    Args:
        paths: File paths attached to the incoming message.

    Returns:
        ``(image_count, video_count)``. A path counts as a video when it ends
        in ``.mp4`` and as an image when it ends in ``.png``, ``.jpg``,
        ``.jpeg``, ``.gif`` or ``.webp``. Any other path is ignored.
    """
    image_count = 0
    video_count = 0
    for path in paths:
        # Both checks are case-insensitive so that "clip.MP4" is treated the
        # same way as "photo.PNG" already was.
        if path.lower().endswith(".mp4"):
            video_count += 1
        elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", path, re.IGNORECASE):
            image_count += 1
    return image_count, video_count
242
-
243
-
244
def count_files_in_history(history: list[dict]) -> tuple[int, int]:
    """Count the image and video files already uploaded by the user.

    Only ``user`` entries whose ``content`` is a non-empty sequence are
    inspected, and only the first element of that sequence is looked at.
    Plain-text entries and assistant entries are skipped.

    Args:
        history: Chat history entries, each a dict with ``role`` and
            ``content`` keys.

    Returns:
        ``(image_count, video_count)`` using the same extension rules as the
        new-message counter (``.mp4`` for video; png/jpg/jpeg/gif/webp for
        images), matched case-insensitively.
    """
    image_count = 0
    video_count = 0
    for item in history:
        content = item["content"]
        if item["role"] != "user" or isinstance(content, str):
            continue
        # NOTE(review): file messages may arrive as a tuple rather than a
        # list depending on the Gradio version, so both are accepted.
        if not isinstance(content, (list, tuple)) or len(content) == 0:
            continue
        file_path = content[0]
        if not isinstance(file_path, str):
            continue
        if file_path.lower().endswith(".mp4"):
            video_count += 1
        elif re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE):
            image_count += 1
    return image_count, video_count
258
-
259
-
260
def validate_media_constraints(message: dict, history: list[dict]) -> bool:
    """Check the upload limits for the new message combined with the history.

    The rules enforced are: at most one video; no mixing of images and a
    video; no ``<image>`` tags alongside a video; at most ``MAX_NUM_IMAGES``
    images; and, when ``<image>`` tags are used, exactly one tag per image
    attached to this message. Each violation raises a ``gr.Warning`` for the
    user and makes the function return ``False``.

    Args:
        message: The incoming message; its ``text`` and ``files`` entries are
            read. Missing or ``None`` entries are treated as empty.
        history: Earlier chat entries, whose media files count toward the
            limits.

    Returns:
        ``True`` when the message may be processed, ``False`` otherwise.
    """
    image_pattern = r"\.(png|jpg|jpeg|gif|webp)$"
    # Tolerate absent keys instead of raising KeyError.
    files = message.get("files") or []
    text = message.get("text") or ""

    image_files = [f for f in files if re.search(image_pattern, f, re.IGNORECASE)]
    media_files = [
        f
        for f in files
        if re.search(image_pattern, f, re.IGNORECASE) or f.lower().endswith(".mp4")
    ]

    new_image_count, new_video_count = count_files_in_new_message(media_files)
    history_image_count, history_video_count = count_files_in_history(history)
    image_count = history_image_count + new_image_count
    video_count = history_video_count + new_video_count

    if video_count > 1:
        gr.Warning("Only one video is supported.")
        return False
    if video_count == 1:
        if image_count > 0:
            gr.Warning("Mixing images and videos is not allowed.")
            return False
        if "<image>" in text:
            gr.Warning("Using <image> tags with video files is not supported.")
            return False
    if video_count == 0 and image_count > MAX_NUM_IMAGES:
        gr.Warning(f"You can upload up to {MAX_NUM_IMAGES} images.")
        return False

    # Tags are matched against the images of THIS message only, not history.
    if "<image>" in text and text.count("<image>") != len(image_files):
        gr.Warning("The number of <image> tags in the text does not match the number of image files.")
        return False

    return True
293
-
294
-
295
  ##############################################################################
296
- # 비디오 처리 - 임시 파일 추적 코드 추가
297
  ##############################################################################
298
def downsample_video(video_path: str) -> list[tuple[Image.Image, float]]:
    """Sample up to five half-size RGB frames from a video.

    Frames are taken every ``frame_interval`` frames, where the interval is
    the larger of the integer FPS and one tenth of the frame count.

    Args:
        video_path: Path of the video file to open with OpenCV.

    Returns:
        A list of ``(frame, timestamp)`` pairs, where ``frame`` is a PIL image
        and ``timestamp`` is the frame position in seconds rounded to two
        decimals (``0.0`` when the container reports no usable FPS).
    """
    vidcap = cv2.VideoCapture(video_path)
    frames = []
    try:
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        total_frames = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp to 1: an FPS of 0 on a short clip would otherwise produce a
        # zero step and make range() raise ValueError.
        frame_interval = max(1, int(fps), int(total_frames / 10))

        for i in range(0, total_frames, frame_interval):
            vidcap.set(cv2.CAP_PROP_POS_FRAMES, i)
            success, image = vidcap.read()
            if not success:
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Halve both dimensions to reduce the size of each frame.
            image = cv2.resize(image, (0, 0), fx=0.5, fy=0.5)
            pil_image = Image.fromarray(image)
            # Guard the division: an FPS of 0 used to raise ZeroDivisionError.
            timestamp = round(i / fps, 2) if fps > 0 else 0.0
            frames.append((pil_image, timestamp))
            if len(frames) >= 5:
                break
    finally:
        # Release the capture handle even when decoding fails midway.
        vidcap.release()
    return frames
320
-
321
-
322
def process_video(video_path: str) -> tuple[list[dict], list[str]]:
    """Turn a video into alternating "Frame <t>:" text items and image items.

    Each sampled frame is written to its own temporary PNG file. The files
    are created with ``delete=False``, so the caller is responsible for
    removing the returned paths once it is done with them.

    Args:
        video_path: Path of the video to sample with ``downsample_video``.

    Returns:
        ``(content, temp_files)``: the content items for the chat template
        and the paths of the PNG files that were created.
    """
    content = []
    temp_files = []  # paths handed back to the caller for later deletion

    try:
        for pil_image, timestamp in downsample_video(video_path):
            # Only reserve the name here and close the handle before saving:
            # re-opening a still-open NamedTemporaryFile by name is not
            # portable across platforms.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
                frame_path = temp_file.name
            temp_files.append(frame_path)
            pil_image.save(frame_path)
            content.append({"type": "text", "text": f"Frame {timestamp}:"})
            content.append({"type": "image", "url": frame_path})
    except Exception:
        # The caller never receives temp_files when we raise, so remove what
        # was created so far (best effort) before propagating the error.
        for leftover in temp_files:
            try:
                os.unlink(leftover)
            except OSError:
                pass
        raise

    return content, temp_files
336
-
337
-
338
- ##############################################################################
339
- # interleaved <image> 처리
340
- ##############################################################################
341
def process_interleaved_images(message: dict) -> list[dict]:
    """Split the message text on ``<image>`` tags and interleave image files.

    The n-th ``<image>`` tag is replaced by the n-th image file (png, jpg,
    jpeg, gif or webp, matched case-insensitively) from ``message["files"]``.
    Tags without a remaining image are dropped. Text between tags is
    stripped, and empty or whitespace-only segments are omitted.

    Args:
        message: Dict with the ``text`` to split and the attached ``files``.

    Returns:
        Content items of the form ``{"type": "text", "text": ...}`` and
        ``{"type": "image", "url": ...}`` in reading order.
    """
    image_files = [
        f for f in message["files"]
        if re.search(r"\.(png|jpg|jpeg|gif|webp)$", f, re.IGNORECASE)
    ]
    remaining_images = iter(image_files)

    content = []
    # The capture group keeps the "<image>" separators in the split result.
    for part in re.split(r"(<image>)", message["text"]):
        if part == "<image>":
            image_path = next(remaining_images, None)
            if image_path is not None:
                content.append({"type": "image", "url": image_path})
            continue
        stripped = part.strip()
        # re.split yields "" around leading, trailing or adjacent tags; these
        # used to be emitted as empty text items.
        if stripped:
            content.append({"type": "text", "text": stripped})
    return content
358
-
359
-
360
- ##############################################################################
361
- # PDF + CSV + TXT + 이미지/비디오
362
- ##############################################################################
363
def is_image_file(file_path: str) -> bool:
    """Return True when the path ends in png, jpg, jpeg, gif or webp (any case)."""
    return bool(re.search(r"\.(png|jpg|jpeg|gif|webp)$", file_path, re.IGNORECASE))


def is_video_file(file_path: str) -> bool:
    """Return True when the path ends in ``.mp4`` (any case).

    The comparison is case-insensitive so that it agrees with the image and
    document extension checks, which already ignore case.
    """
    return file_path.lower().endswith(".mp4")
368
-
369
  def is_document_file(file_path: str) -> bool:
370
  return (
371
  file_path.lower().endswith(".pdf")
@@ -373,87 +209,59 @@ def is_document_file(file_path: str) -> bool:
373
  or file_path.lower().endswith(".txt")
374
  )
375
 
376
-
377
def process_new_user_message(message: dict) -> tuple[list[dict], list[str]]:
    """Build the content items for the newest user message.

    The user's text comes first, followed by one text item per attached CSV,
    TXT and PDF file (in that order). After that:

    * if a video is attached, the frames of the first video are appended and
      any images are ignored;
    * otherwise, if the text contains ``<image>`` tags and images are
      attached, the text is replaced by its interleaved text/image form and
      the document items follow it;
    * otherwise every image is appended after the documents.

    Args:
        message: Dict with the user's ``text`` and the attached ``files``.

    Returns:
        ``(content, temp_files)``: the content items, plus the temporary file
        paths reported by video processing (empty when no video is attached).
    """
    temp_files: list[str] = []
    attachments = message["files"]

    # Text-only message: nothing to analyse.
    if not attachments:
        return [{"type": "text", "text": message["text"]}], temp_files

    videos = [path for path in attachments if is_video_file(path)]
    images = [path for path in attachments if is_image_file(path)]
    csv_paths = [path for path in attachments if path.lower().endswith(".csv")]
    txt_paths = [path for path in attachments if path.lower().endswith(".txt")]
    pdf_paths = [path for path in attachments if path.lower().endswith(".pdf")]

    user_text = message["text"]
    prompt_item = {"type": "text", "text": user_text}

    # Documents are rendered to text, grouped by type: CSV, then TXT, then PDF.
    document_items = []
    for paths, render in (
        (csv_paths, analyze_csv_file),
        (txt_paths, analyze_txt_file),
        (pdf_paths, pdf_to_markdown),
    ):
        for path in paths:
            document_items.append({"type": "text", "text": render(path)})

    if videos:
        # Only the first video is processed; it takes precedence over images.
        frame_items, frame_paths = process_video(videos[0])
        temp_files.extend(frame_paths)
        return [prompt_item] + document_items + frame_items, temp_files

    if "<image>" in user_text and images:
        # The interleaved form already carries the user's text, so the plain
        # prompt item is left out here.
        interleaved = process_interleaved_images({"text": user_text, "files": images})
        return interleaved + document_items, temp_files

    image_items = [{"type": "image", "url": path} for path in images]
    return [prompt_item] + document_items + image_items, temp_files
419
-
420
 
421
  ##############################################################################
422
- # history -> LLM 메시지 변환
423
  ##############################################################################
424
def process_history(history: list[dict]) -> list[dict]:
    """Convert chat history entries into chat-template messages.

    Consecutive user entries are merged into a single ``user`` message whose
    content is a list of items. Each assistant entry closes the pending user
    message and is emitted as its own ``assistant`` message. For a user entry
    whose content is a non-empty list, only the first element is used: it
    becomes an image item when it is an image file, otherwise a
    ``[File: <name>]`` text placeholder.

    Args:
        history: Entries with ``role`` and ``content`` keys.

    Returns:
        The list of ``{"role": ..., "content": [...]}`` messages.
    """
    messages: list[dict] = []
    pending: list[dict] = []

    def flush_pending() -> None:
        # Emit the user items gathered so far as one message, then reset.
        if pending:
            messages.append({"role": "user", "content": list(pending)})
            pending.clear()

    for entry in history:
        if entry["role"] == "assistant":
            flush_pending()
            messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": entry["content"]}]}
            )
            continue

        body = entry["content"]
        if isinstance(body, str):
            pending.append({"type": "text", "text": body})
            continue
        if not (isinstance(body, list) and len(body) > 0):
            continue

        first_path = body[0]
        if is_image_file(first_path):
            pending.append({"type": "image", "url": first_path})
        else:
            pending.append({"type": "text", "text": f"[File: {os.path.basename(first_path)}]"})

    flush_pending()
    return messages
448
-
449
 
450
  ##############################################################################
451
- # 모델 생성 함수에서 OOM 캐치
452
  ##############################################################################
453
  def _model_gen_with_oom_catch(**kwargs):
454
- """
455
- 별도 스레드에서 OutOfMemoryError를 잡아주기 위해
456
- """
457
  try:
458
  model.generate(**kwargs)
459
  except torch.cuda.OutOfMemoryError:
@@ -462,12 +270,10 @@ def _model_gen_with_oom_catch(**kwargs):
462
  "Max New Tokens을 줄이거나, 프롬프트 길이를 줄여주세요."
463
  )
464
  finally:
465
- # 생성 완료 후 한번 더 캐시 비우기
466
  clear_cuda_cache()
467
 
468
-
469
  ##############################################################################
470
- # 메인 추론 함수 (web search 체크 시 자동 키워드추출->검색->결과 system msg)
471
  ##############################################################################
472
  @spaces.GPU(duration=120)
473
  def run(
@@ -478,111 +284,83 @@ def run(
478
  use_web_search: bool = False,
479
  web_search_query: str = "",
480
  ) -> Iterator[str]:
481
-
482
- if not validate_media_constraints(message, history):
483
- yield ""
484
- return
485
-
486
- temp_files = [] # 임시 파일 추적용
487
 
488
  try:
489
- combined_system_msg = ""
490
-
491
- # 내부적으로만 사용 (UI에서는 보이지 않음)
 
492
  if system_prompt.strip():
493
- combined_system_msg += f"[System Prompt]\n{system_prompt.strip()}\n\n"
494
-
 
495
  if use_web_search:
496
  user_text = message["text"]
497
  ws_query = extract_keywords(user_text, top_k=5)
498
  if ws_query.strip():
499
  logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
500
  ws_result = do_web_search(ws_query)
501
- combined_system_msg += f"[Search top-20 Full Items Based on user prompt]\n{ws_result}\n\n"
502
- # >>> 추가된 안내 문구 (검색 결과의 link 등 출처를 활용)
503
- combined_system_msg += "[참고: 위 검색결과 내용과 link를 출처로 인용하여 답변해 주세요.]\n\n"
504
- combined_system_msg += """
505
- [중요 지시사항]
506
- 1. 답변에 검색 결과에서 찾은 ��보의 출처를 반드시 인용하세요.
507
- 2. 출처 인용 시 "[출처 제목](링크)" 형식의 마크다운 링크를 사용하세요.
508
- 3. 여러 출처의 정보를 종합하여 답변하세요.
509
- 4. 답변 마지막에 "참고 자료:" 섹션을 추가하고 사용한 주요 출처 링크를 나열하세요.
510
- """
511
- else:
512
- combined_system_msg += "[No valid keywords found, skipping WebSearch]\n\n"
513
-
514
- messages = []
515
- if combined_system_msg.strip():
516
- messages.append({
517
- "role": "system",
518
- "content": [{"type": "text", "text": combined_system_msg.strip()}],
519
- })
520
-
521
- messages.extend(process_history(history))
522
-
523
- user_content, user_temp_files = process_new_user_message(message)
524
- temp_files.extend(user_temp_files) # 임시 파일 추적
525
 
526
- for item in user_content:
527
- if item["type"] == "text" and len(item["text"]) > MAX_CONTENT_CHARS:
528
- item["text"] = item["text"][:MAX_CONTENT_CHARS] + "\n...(truncated)..."
529
- messages.append({"role": "user", "content": user_content})
530
-
531
- inputs = processor.apply_chat_template(
532
- messages,
533
- add_generation_prompt=True,
534
- tokenize=True,
535
- return_dict=True,
 
 
536
  return_tensors="pt",
537
- ).to(device=model.device, dtype=torch.bfloat16)
 
 
538
 
539
- # 입력 토큰 수 제한 추가
540
- if inputs.input_ids.shape[1] > MAX_INPUT_LENGTH:
541
- inputs.input_ids = inputs.input_ids[:, -MAX_INPUT_LENGTH:]
542
- if 'attention_mask' in inputs:
543
- inputs.attention_mask = inputs.attention_mask[:, -MAX_INPUT_LENGTH:]
 
 
544
 
545
- streamer = TextIteratorStreamer(processor, timeout=30.0, skip_prompt=True, skip_special_tokens=True)
546
  gen_kwargs = dict(
547
  inputs,
548
  streamer=streamer,
549
  max_new_tokens=max_new_tokens,
 
 
 
550
  )
551
-
 
552
  t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
553
  t.start()
554
-
 
555
  output = ""
556
  for new_text in streamer:
557
  output += new_text
558
  yield output
559
-
560
  except Exception as e:
561
  logger.error(f"Error in run: {str(e)}")
562
  yield f"죄송합니다. 오류가 발생했습니다: {str(e)}"
563
 
564
  finally:
565
- # 임시 파일 삭제
566
- for temp_file in temp_files:
567
- try:
568
- if os.path.exists(temp_file):
569
- os.unlink(temp_file)
570
- logger.info(f"Deleted temp file: {temp_file}")
571
- except Exception as e:
572
- logger.warning(f"Failed to delete temp file {temp_file}: {e}")
573
-
574
- # 명시적 메모리 정리
575
  try:
576
- del inputs, streamer
577
  except:
578
  pass
579
-
580
  clear_cuda_cache()
581
 
582
-
583
-
584
  ##############################################################################
585
- # 예시들 (모두 영어로)
586
  ##############################################################################
587
  examples = [
588
  [
@@ -602,100 +380,49 @@ examples = [
602
  ],
603
  [
604
  {
605
- "text": "Assume the role of a friendly and understanding girlfriend. Describe this video.",
606
- "files": ["assets/additional-examples/tmp.mp4"],
607
- }
608
- ],
609
- [
610
- {
611
- "text": "Describe the cover and read the text on it.",
612
- "files": ["assets/additional-examples/maz.jpg"],
613
- }
614
- ],
615
- [
616
- {
617
- "text": "I already have this supplement <image> and I plan to buy this product <image>. Are there any precautions when taking them together?",
618
- "files": ["assets/additional-examples/pill1.png", "assets/additional-examples/pill2.png"],
619
  }
620
  ],
621
  [
622
  {
623
- "text": "Solve this integral.",
624
- "files": ["assets/additional-examples/4.png"],
625
  }
626
  ],
627
  [
628
  {
629
- "text": "When was this ticket issued, and what is its price?",
630
- "files": ["assets/additional-examples/2.png"],
631
  }
632
  ],
633
- [
634
- {
635
- "text": "Based on the sequence of these images, create a short story.",
636
- "files": [
637
- "assets/sample-images/09-1.png",
638
- "assets/sample-images/09-2.png",
639
- "assets/sample-images/09-3.png",
640
- "assets/sample-images/09-4.png",
641
- "assets/sample-images/09-5.png",
642
- ],
643
- }
644
- ],
645
- [
646
- {
647
- "text": "Write Python code using matplotlib to plot a bar chart that matches this image.",
648
- "files": ["assets/additional-examples/barchart.png"],
649
- }
650
- ],
651
- [
652
- {
653
- "text": "Read the text in the image and write it out in Markdown format.",
654
- "files": ["assets/additional-examples/3.png"],
655
- }
656
- ],
657
- [
658
- {
659
- "text": "What does this sign say?",
660
- "files": ["assets/sample-images/02.png"],
661
- }
662
- ],
663
- [
664
- {
665
- "text": "Compare the two images and describe their similarities and differences.",
666
- "files": ["assets/sample-images/03.png"],
667
- }
668
- ],
669
  ]
670
 
671
  ##############################################################################
672
- # Gradio UI (Blocks) 구성 (좌측 사이드 메뉴 없이 전체화면 채팅)
673
  ##############################################################################
674
  css = """
675
- /* 1) UI를 처음부터 가장 넓게 (width 100%) 고정하여 표시 */
676
  .gradio-container {
677
- background: rgba(255, 255, 255, 0.7); /* 배경 투명도 증가 */
678
  padding: 30px 40px;
679
- margin: 20px auto; /* 위아래 여백만 유지 */
680
  width: 100% !important;
681
- max-width: none !important; /* 1200px 제한 제거 */
682
  }
683
  .fillable {
684
  width: 100% !important;
685
  max-width: 100% !important;
686
  }
687
- /* 2) 배경을 완전히 투명하게 변경 */
688
  body {
689
- background: transparent; /* 완전 투명 배경 */
690
  margin: 0;
691
  padding: 0;
692
  font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
693
  color: #333;
694
  }
695
- /* 버튼 색상 완전히 제거하고 투명하게 */
696
  button, .btn {
697
- background: transparent !important; /* 색상 완전히 제거 */
698
- border: 1px solid #ddd; /* 경계선만 살짝 추가 */
699
  color: #333;
700
  padding: 12px 24px;
701
  text-transform: uppercase;
@@ -704,93 +431,32 @@ button, .btn {
704
  cursor: pointer;
705
  }
706
  button:hover, .btn:hover {
707
- background: rgba(0, 0, 0, 0.05) !important; /* 호버 시 아주 살짝 어둡게만 */
708
- }
709
-
710
- /* examples 관련 모든 색상 제거 */
711
- #examples_container, .examples-container {
712
- margin: auto;
713
- width: 90%;
714
- background: transparent !important;
715
- }
716
- #examples_row, .examples-row {
717
- justify-content: center;
718
- background: transparent !important;
719
- }
720
-
721
- /* examples 버튼 내부의 모든 색상 제거 */
722
- .gr-samples-table button,
723
- .gr-samples-table .gr-button,
724
- .gr-samples-table .gr-sample-btn,
725
- .gr-examples button,
726
- .gr-examples .gr-button,
727
- .gr-examples .gr-sample-btn,
728
- .examples button,
729
- .examples .gr-button,
730
- .examples .gr-sample-btn {
731
- background: transparent !important;
732
- border: 1px solid #ddd;
733
- color: #333;
734
- }
735
-
736
- /* examples 버튼 호버 시에도 색상 없게 */
737
- .gr-samples-table button:hover,
738
- .gr-samples-table .gr-button:hover,
739
- .gr-samples-table .gr-sample-btn:hover,
740
- .gr-examples button:hover,
741
- .gr-examples .gr-button:hover,
742
- .gr-examples .gr-sample-btn:hover,
743
- .examples button:hover,
744
- .examples .gr-button:hover,
745
- .examples .gr-sample-btn:hover {
746
  background: rgba(0, 0, 0, 0.05) !important;
747
  }
748
-
749
- /* 채팅 인터페이스 요소들도 투명하게 */
750
- .chatbox, .chatbot, .message {
751
- background: transparent !important;
752
- }
753
-
754
- /* 입력창 투명도 조정 */
755
- .multimodal-textbox, textarea, input {
756
- background: rgba(255, 255, 255, 0.5) !important;
757
- }
758
-
759
- /* 모든 컨테이너 요소에 배경색 제거 */
760
- .container, .wrap, .box, .panel, .gr-panel {
761
- background: transparent !important;
762
- }
763
-
764
- /* 예제 섹션의 모든 요소에서 배경색 제거 */
765
- .gr-examples-container, .gr-examples, .gr-sample, .gr-sample-row, .gr-sample-cell {
766
- background: transparent !important;
767
- }
768
  """
769
 
770
  title_html = """
771
- <h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> 🤗 Gemma3-R1984-1B </h1>
772
  <p align="center" style="font-size:1.1em; color:#555;">
773
- ✅Agentic AI Platform ✅Reasoning & Uncensored Multimodal & VLM ✅Deep-Research & RAG <br>
774
- Operates on an 'NVIDIA L40s / A100(ZeroGPU) GPU' as an independent local server, enhancing security and preventing information leakage.<br>
775
- @Model Rpository: VIDraft/Gemma-3-R1984-1B, @Based by 'Google Gemma-3-1b', @Powered by 'MOUSE-II'(VIDRAFT)
 
776
  </p>
777
  """
778
 
779
-
780
  with gr.Blocks(css=css, title="Gemma3-R1984-1B") as demo:
781
  gr.Markdown(title_html)
782
 
783
- # Display the web search option (while the system prompt and token slider remain hidden)
784
  web_search_checkbox = gr.Checkbox(
785
  label="Deep Research",
786
  value=False
787
  )
788
 
789
- # Used internally but not visible to the user
790
  system_prompt_box = gr.Textbox(
791
  lines=3,
792
  value="You are a deep thinking AI that may use extremely long chains of thought to thoroughly analyze the problem and deliberate using systematic reasoning processes to arrive at a correct solution before answering.",
793
- visible=False # hidden from view
794
  )
795
 
796
  max_tokens_slider = gr.Slider(
@@ -799,26 +465,22 @@ with gr.Blocks(css=css, title="Gemma3-R1984-1B") as demo:
799
  maximum=8000,
800
  step=50,
801
  value=1000,
802
- visible=False # hidden from view
803
  )
804
 
805
  web_search_text = gr.Textbox(
806
  lines=1,
807
  label="(Unused) Web Search Query",
808
  placeholder="No direct input needed",
809
- visible=False # hidden from view
810
  )
811
 
812
- # Configure the chat interface
813
  chat = gr.ChatInterface(
814
  fn=run,
815
  type="messages",
816
- chatbot=gr.Chatbot(type="messages", scale=1, allow_tags=["image"]),
817
  textbox=gr.MultimodalTextbox(
818
- file_types=[
819
- ".webp", ".png", ".jpg", ".jpeg", ".gif",
820
- ".mp4", ".csv", ".txt", ".pdf"
821
- ],
822
  file_count="multiple",
823
  autofocus=True
824
  ),
@@ -838,12 +500,9 @@ with gr.Blocks(css=css, title="Gemma3-R1984-1B") as demo:
838
  delete_cache=(1800, 1800),
839
  )
840
 
841
- # Example section - since examples are already set in ChatInterface, this is for display only
842
  with gr.Row(elem_id="examples_row"):
843
  with gr.Column(scale=12, elem_id="examples_container"):
844
  gr.Markdown("### Example Inputs (click to load)")
845
 
846
-
847
  if __name__ == "__main__":
848
- # Run locally
849
- demo.launch()
 
3
  import os
4
  import re
5
  import tempfile
6
+ import gc
7
  from collections.abc import Iterator
8
  from threading import Thread
9
  import json
10
  import requests
 
11
  import gradio as gr
12
  import spaces
13
  import torch
14
  from loguru import logger
15
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
 
16
 
17
  # CSV/TXT 분석
18
  import pandas as pd
 
49
 
50
  ##############################################################################
51
  # SerpHouse Live endpoint 호출
 
52
  ##############################################################################
53
  def do_web_search(query: str) -> str:
54
  """
 
58
  try:
59
  url = "https://api.serphouse.com/serp/live"
60
 
 
61
  params = {
62
  "q": query,
63
  "domain": "google.com",
64
+ "serp_type": "web",
65
  "device": "desktop",
66
  "lang": "en",
67
+ "num": "20"
68
  }
69
 
70
  headers = {
 
72
  }
73
 
74
  logger.info(f"SerpHouse API 호출 중... 검색어: {query}")
 
75
 
 
76
  response = requests.get(url, headers=headers, params=params, timeout=60)
77
  response.raise_for_status()
78
 
 
79
  data = response.json()
80
 
81
  # 다양한 응답 구조 처리
82
  results = data.get("results", {})
83
  organic = None
84
 
 
85
  if isinstance(results, dict) and "organic" in results:
86
  organic = results["organic"]
 
 
87
  elif isinstance(results, dict) and "results" in results:
88
  if isinstance(results["results"], dict) and "organic" in results["results"]:
89
  organic = results["results"]["organic"]
 
 
90
  elif "organic" in data:
91
  organic = data["organic"]
92
 
93
  if not organic:
94
  logger.warning("응답에서 organic 결과를 찾을 수 없습니다.")
 
 
 
95
  return "No web search results found or unexpected API response structure."
96
 
97
  # 결과 수 제한 및 컨텍스트 길이 최적화
98
  max_results = min(20, len(organic))
99
  limited_organic = organic[:max_results]
100
 
101
+ # 결과 형식 개선 - 마크다운 형식으로 출력
102
  summary_lines = []
103
  for idx, item in enumerate(limited_organic, start=1):
104
  title = item.get("title", "No title")
 
106
  snippet = item.get("snippet", "No description")
107
  displayed_link = item.get("displayed_link", link)
108
 
 
109
  summary_lines.append(
110
  f"### Result {idx}: {title}\n\n"
111
  f"{snippet}\n\n"
 
113
  f"---\n"
114
  )
115
 
 
116
  instructions = """
117
  # 웹 검색 결과
118
  아래는 검색 결과입니다. 질문에 답변할 때 이 정보를 활용하세요:
 
130
  logger.error(f"Web search failed: {e}")
131
  return f"Web search failed: {str(e)}"
132
 
 
133
  ##############################################################################
134
+ # 모델/토크나이저 로딩 (텍스트 전용)
135
  ##############################################################################
136
  MAX_CONTENT_CHARS = 2000
137
+ MAX_INPUT_LENGTH = 2096
138
  model_id = os.getenv("MODEL_ID", "VIDraft/Gemma-3-R1984-1B")
139
 
140
+ # 텍스트 전용 모델로 로드
141
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
142
+ model = AutoModelForCausalLM.from_pretrained(
143
  model_id,
144
  device_map="auto",
145
  torch_dtype=torch.bfloat16,
146
+ attn_implementation="eager"
147
  )
 
 
148
 
149
  ##############################################################################
150
  # CSV, TXT, PDF 분석 함수
151
  ##############################################################################
152
  def analyze_csv_file(path: str) -> str:
153
+ """CSV 파일을 전체 문자열로 변환. 너무 길 경우 일부만 표시."""
 
 
154
  try:
155
  df = pd.read_csv(path)
156
  if df.shape[0] > 50 or df.shape[1] > 10:
 
162
  except Exception as e:
163
  return f"Failed to read CSV ({os.path.basename(path)}): {str(e)}"
164
 
 
165
  def analyze_txt_file(path: str) -> str:
166
+ """TXT 파일 전문 읽기. 너무 길면 일부만 표시."""
 
 
167
  try:
168
  with open(path, "r", encoding="utf-8") as f:
169
  text = f.read()
 
173
  except Exception as e:
174
  return f"Failed to read TXT ({os.path.basename(path)}): {str(e)}"
175
 
 
176
  def pdf_to_markdown(pdf_path: str) -> str:
177
+ """PDF 텍스트를 Markdown으로 변환. 페이지별로 간단히 텍스트 추출."""
 
 
178
  text_chunks = []
179
  try:
180
  with open(pdf_path, "rb") as f:
 
199
 
200
  return f"**[PDF File: {os.path.basename(pdf_path)}]**\n\n{full_text}"
201
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
202
  ##############################################################################
203
+ # 문서 파일 확인
204
  ##############################################################################
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  def is_document_file(file_path: str) -> bool:
206
  return (
207
  file_path.lower().endswith(".pdf")
 
209
  or file_path.lower().endswith(".txt")
210
  )
211
 
212
+ ##############################################################################
213
+ # 메시지 처리 (텍스트 문서 파일만)
214
+ ##############################################################################
215
def process_new_user_message(message: dict) -> str:
    """Combine the user's text and attached document files into one string.

    Attached ``.csv``, ``.txt`` and ``.pdf`` files (matched case-insensitively)
    are converted to text and appended after the user's text, grouped in that
    order. Files with any other extension are ignored. Parts are separated by
    a blank line.

    Args:
        message: Dict with the user's ``text`` and optional ``files``. A
            missing or ``None`` value for either key is treated as empty.

    Returns:
        The combined prompt text. Empty text contributes no part, so a
        message with attachments but no text does not start with a blank
        separator.
    """
    user_text = message.get("text") or ""
    attachments = message.get("files") or []

    content_parts = [user_text] if user_text else []

    if attachments:
        # The converters are looked up only when there is something to
        # convert. Grouping by type keeps the CSV -> TXT -> PDF order.
        converters = (
            (".csv", analyze_csv_file),
            (".txt", analyze_txt_file),
            (".pdf", pdf_to_markdown),
        )
        for suffix, convert in converters:
            content_parts.extend(
                convert(path) for path in attachments if path.lower().endswith(suffix)
            )

    return "\n\n".join(content_parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
238
 
239
  ##############################################################################
240
+ # 대화 히스토리 처리
241
  ##############################################################################
242
def process_history(history: list[dict]) -> str:
    """Convert the chat history into a plain-text transcript.

    Args:
        history: Prior messages, each a dict with "role" and "content".
            Any role other than "assistant" is treated as the user.

    Returns:
        One "\\nAssistant: ...\\n" or "\\nUser: ...\\n" segment per message,
        concatenated in order. A user message whose content is a non-empty
        list or tuple is shown as "[File: <basename of first entry>]".
        User messages with any other content are skipped.
    """
    segments = []

    for item in history:
        content = item["content"]
        if item["role"] == "assistant":
            segments.append(f"\nAssistant: {content}\n")
        elif isinstance(content, str):
            segments.append(f"\nUser: {content}\n")
        elif isinstance(content, (list, tuple)) and content:
            # File attachments may arrive as a tuple as well as a list;
            # only the file name of the first entry is shown.
            segments.append(f"\nUser: [File: {os.path.basename(content[0])}]\n")

    return "".join(segments)
 
 
 
 
 
 
 
259
 
260
  ##############################################################################
261
+ # 모델 생성 함수
262
  ##############################################################################
263
  def _model_gen_with_oom_catch(**kwargs):
264
+ """별도 스레드에서 OutOfMemoryError를 잡아주기 위해"""
 
 
265
  try:
266
  model.generate(**kwargs)
267
  except torch.cuda.OutOfMemoryError:
 
270
  "Max New Tokens을 줄이거나, 프롬프트 길이를 줄여주세요."
271
  )
272
  finally:
 
273
  clear_cuda_cache()
274
 
 
275
  ##############################################################################
276
+ # 메인 추론 함수 (텍스트 전용)
277
  ##############################################################################
278
  @spaces.GPU(duration=120)
279
  def run(
 
284
  use_web_search: bool = False,
285
  web_search_query: str = "",
286
  ) -> Iterator[str]:
 
 
 
 
 
 
287
 
288
  try:
289
+ # 전체 프롬프트 구성
290
+ full_prompt = ""
291
+
292
+ # 시스템 프롬프트
293
  if system_prompt.strip():
294
+ full_prompt += f"System: {system_prompt.strip()}\n\n"
295
+
296
+ # 웹 검색 수행
297
  if use_web_search:
298
  user_text = message["text"]
299
  ws_query = extract_keywords(user_text, top_k=5)
300
  if ws_query.strip():
301
  logger.info(f"[Auto WebSearch Keyword] {ws_query!r}")
302
  ws_result = do_web_search(ws_query)
303
+ full_prompt += f"[Web Search Results]\n{ws_result}\n\n"
304
+ full_prompt += "[중요: 검색결과의 출처를 인용하여 답변해 주세요.]\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
+ # 대화 히스토리
307
+ if history:
308
+ conversation_history = process_history(history)
309
+ full_prompt += conversation_history
310
+
311
+ # 현재 사용자 메시지
312
+ user_content = process_new_user_message(message)
313
+ full_prompt += f"\nUser: {user_content}\nAssistant:"
314
+
315
+ # 토큰화
316
+ inputs = tokenizer(
317
+ full_prompt,
318
  return_tensors="pt",
319
+ truncation=True,
320
+ max_length=MAX_INPUT_LENGTH
321
+ ).to(device=model.device)
322
 
323
+ # 스트리밍 설정
324
+ streamer = TextIteratorStreamer(
325
+ tokenizer,
326
+ timeout=30.0,
327
+ skip_prompt=True,
328
+ skip_special_tokens=True
329
+ )
330
 
 
331
  gen_kwargs = dict(
332
  inputs,
333
  streamer=streamer,
334
  max_new_tokens=max_new_tokens,
335
+ temperature=0.7,
336
+ top_p=0.9,
337
+ do_sample=True,
338
  )
339
+
340
+ # 별도 스레드에서 생성
341
  t = Thread(target=_model_gen_with_oom_catch, kwargs=gen_kwargs)
342
  t.start()
343
+
344
+ # 스트리밍 출력
345
  output = ""
346
  for new_text in streamer:
347
  output += new_text
348
  yield output
349
+
350
  except Exception as e:
351
  logger.error(f"Error in run: {str(e)}")
352
  yield f"죄송합니다. 오류가 발생했습니다: {str(e)}"
353
 
354
  finally:
355
+ # 메모리 정리
 
 
 
 
 
 
 
 
 
356
  try:
357
+ del inputs
358
  except:
359
  pass
 
360
  clear_cuda_cache()
361
 
 
 
362
  ##############################################################################
363
+ # 예시들 (텍스트 및 문서 파일만)
364
  ##############################################################################
365
  examples = [
366
  [
 
380
  ],
381
  [
382
  {
383
+ "text": "What are the key findings from this research paper?",
384
+ "files": ["assets/additional-examples/research.pdf"],
 
 
 
 
 
 
 
 
 
 
 
 
385
  }
386
  ],
387
  [
388
  {
389
+ "text": "Analyze the data trends in this CSV file.",
390
+ "files": ["assets/additional-examples/data.csv"],
391
  }
392
  ],
393
  [
394
  {
395
+ "text": "Summarize the main points from this text document.",
396
+ "files": ["assets/additional-examples/document.txt"],
397
  }
398
  ],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  ]
400
 
401
  ##############################################################################
402
+ # Gradio UI
403
  ##############################################################################
404
  css = """
 
405
  .gradio-container {
406
+ background: rgba(255, 255, 255, 0.7);
407
  padding: 30px 40px;
408
+ margin: 20px auto;
409
  width: 100% !important;
410
+ max-width: none !important;
411
  }
412
  .fillable {
413
  width: 100% !important;
414
  max-width: 100% !important;
415
  }
 
416
  body {
417
+ background: transparent;
418
  margin: 0;
419
  padding: 0;
420
  font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
421
  color: #333;
422
  }
 
423
  button, .btn {
424
+ background: transparent !important;
425
+ border: 1px solid #ddd;
426
  color: #333;
427
  padding: 12px 24px;
428
  text-transform: uppercase;
 
431
  cursor: pointer;
432
  }
433
  button:hover, .btn:hover {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  background: rgba(0, 0, 0, 0.05) !important;
435
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
436
  """
437
 
438
  title_html = """
439
+ <h1 align="center" style="margin-bottom: 0.2em; font-size: 1.6em;"> 🤗 Gemma3-R1984-1B (Text Only) </h1>
440
  <p align="center" style="font-size:1.1em; color:#555;">
441
+ ✅Agentic AI Platform ✅Reasoning ✅Text Analysis ✅Deep-Research & RAG <br>
442
+ Document Processing (PDF, CSV, TXT) ✅Web Search Integration<br>
443
+ Operates on an ✅'NVIDIA L40s / A100(ZeroGPU) GPU' as an independent local server<br>
444
+ @Model Repository: VIDraft/Gemma-3-R1984-1B, @Based by 'Google Gemma-3-1b'
445
  </p>
446
  """
447
 
 
448
  with gr.Blocks(css=css, title="Gemma3-R1984-1B") as demo:
449
  gr.Markdown(title_html)
450
 
 
451
  web_search_checkbox = gr.Checkbox(
452
  label="Deep Research",
453
  value=False
454
  )
455
 
 
456
  system_prompt_box = gr.Textbox(
457
  lines=3,
458
  value="You are a deep thinking AI that may use extremely long chains of thought to thoroughly analyze the problem and deliberate using systematic reasoning processes to arrive at a correct solution before answering.",
459
+ visible=False
460
  )
461
 
462
  max_tokens_slider = gr.Slider(
 
465
  maximum=8000,
466
  step=50,
467
  value=1000,
468
+ visible=False
469
  )
470
 
471
  web_search_text = gr.Textbox(
472
  lines=1,
473
  label="(Unused) Web Search Query",
474
  placeholder="No direct input needed",
475
+ visible=False
476
  )
477
 
 
478
  chat = gr.ChatInterface(
479
  fn=run,
480
  type="messages",
481
+ chatbot=gr.Chatbot(type="messages", scale=1),
482
  textbox=gr.MultimodalTextbox(
483
+ file_types=[".csv", ".txt", ".pdf"], # 이미지/비디오 제거
 
 
 
484
  file_count="multiple",
485
  autofocus=True
486
  ),
 
500
  delete_cache=(1800, 1800),
501
  )
502
 
 
503
  with gr.Row(elem_id="examples_row"):
504
  with gr.Column(scale=12, elem_id="examples_container"):
505
  gr.Markdown("### Example Inputs (click to load)")
506
 
 
507
  if __name__ == "__main__":
508
+ demo.launch()